Integrated Gradients vs Activation Patching in GPT2-Small

Background, motivation and set up

Objective: Compare attributions using integrated gradients and activation patching, and investigate the discrepancies between the two methods.

Motivation:

  • Understand when and why IG and AP disagree: e.g. due to methodological limitations, or differing suitability to model tasks
  • Investigate if discrepancies help uncover different hidden model behaviours
  • Understand when and why linear approximations to activation patching fail
  • Investigate limitations of using activation patching for evaluations: if results are different because of other unknown factors (not just because the method evaluated is “incorrect”)

Set-up:

We load the transformer model GPT2-Small, which has 12 layers, 12 attention heads per layer, embedding size 768 and 4 x 768 = 3,072 neurons in each feed-forward layer. We use GPT2-Small because 1) it is a relatively small transformer model which has comparable behaviour to larger SOTA models, and 2) there is a lot of interpretability literature which focuses on circuits in this model.

Code
import torch
import numpy as np

from captum.attr import LayerIntegratedGradients

from transformer_lens.utils import get_act_name, get_device
from transformer_lens import ActivationCache, HookedTransformer, HookedTransformerConfig
from transformer_lens.hook_points import HookPoint

import seaborn as sns
import matplotlib.pyplot as plt
Code
# Disable gradient tracking globally: activation patching only needs forward passes.
# NOTE(review): the integrated-gradients cells below do need gradients — presumably
# captum re-enables autograd internally during .attribute(); confirm for your version.
torch.set_grad_enabled(False)

device = get_device()
# Load pretrained GPT2-Small wrapped in TransformerLens's HookedTransformer
model = HookedTransformer.from_pretrained("gpt2-small", device=device)
Loaded pretrained model gpt2-small into HookedTransformer

Attribution for GPT2-Small

We scale up our earlier experiments to implement integrated gradients and activation patching on a larger transformer model. We use the same counterfactual inputs, based on the Indirect Object Identification task.

Code
# IOI counterfactual pair: the prompts differ only in the repeated name
# ("Mary" vs "John"), which flips the expected indirect object.
clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"
corrupted_prompt = "After John and Mary went to the store, John gave a bottle of milk to"

# Tokenise to [1, seq_len] tensors (assumes to_tokens prepends BOS by default — confirm)
clean_input = model.to_tokens(clean_prompt)
corrupted_input = model.to_tokens(corrupted_prompt)

def logits_to_logit_diff(logits, correct_answer=" John", incorrect_answer=" Mary"):
    """Return logit(correct) - logit(incorrect) at the final sequence position.

    Uses the module-level `model` to map the single-token answer strings to
    vocabulary indices (model.to_single_token).
    """
    final_position_logits = logits[0, -1]
    correct_logit = final_position_logits[model.to_single_token(correct_answer)]
    incorrect_logit = final_position_logits[model.to_single_token(incorrect_answer)]
    return correct_logit - incorrect_logit

# Explicitly calculate and expose the result for each attention head
model.set_use_attn_result(True)
# Also expose hook_mlp_in so per-layer MLP inputs can be cached and overridden
model.set_use_hook_mlp_in(True)

# Full forward passes with activation caching; the caches are reused below as
# patching sources and as IG endpoints.
clean_logits, clean_cache = model.run_with_cache(clean_input)
clean_logit_diff = logits_to_logit_diff(clean_logits)
print(f"Clean logit difference: {clean_logit_diff.item():.3f}")

corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_input)
corrupted_logit_diff = logits_to_logit_diff(corrupted_logits)
print(f"Corrupted logit difference: {corrupted_logit_diff.item():.3f}")
Clean logit difference: 4.276
Corrupted logit difference: -2.738

Integrated Gradients

Code
def run_from_layer_fn(x, original_input, prev_layer, reset_hooks_end=True):
    """Run the model with prev_layer's output forcibly replaced by `x`.

    Effectively evaluates only the computation downstream of prev_layer;
    returns the logit difference as a 1-element tensor (captum expects a
    batched scalar output).
    """
    # Force the layer before the target layer to output the given values, i.e. pass the given input into the target layer
    # original_input value does not matter; useful to keep shapes nice, but its activations will be overwritten
    
    def fwd_hook(act, hook):
        # Ignore the real activations `act`; substitute x wholesale.
        # requires_grad so captum can backpropagate through the override.
        x.requires_grad_(True)
        return x
    
    logits = model.run_with_hooks(
        original_input,
        fwd_hooks=[(prev_layer.name, fwd_hook)],
        reset_hooks_end=reset_hooks_end
    )
    logit_diff = logits_to_logit_diff(logits).unsqueeze(0)
    return logit_diff

def compute_layer_to_output_attributions(original_input, layer_input, layer_baseline, target_layer, prev_layer):
    """Attribute the logit difference to target_layer's output via layer IG.

    Overrides prev_layer's activations (interpolated between layer_baseline
    and layer_input) and integrates gradients of the downstream logit diff.
    Prints the convergence delta and returns the attribution tensor.
    """
    def forward_fn(x):
        # The model restricted to computation downstream of prev_layer
        return run_from_layer_fn(x, original_input, prev_layer)

    # Attribute to the target_layer's output
    layer_ig = LayerIntegratedGradients(forward_fn, target_layer, multiply_by_inputs=True)
    attributions, delta = layer_ig.attribute(
        inputs=layer_input,
        baselines=layer_baseline,
        attribute_to_layer_input=False,
        return_convergence_delta=True,
    )
    print(f"\nError (delta) for {target_layer.name} attribution: {delta.item()}")
    return attributions
Code
# Reload previously-saved IG (zero-baseline) results.
# NOTE(review): the next cell unconditionally recomputes and overwrites these —
# run only one of the two cells per session.
mlp_ig_zero_results = torch.load("mlp_ig_zero_results.pt")
attn_ig_zero_results = torch.load("attn_ig_zero_results.pt")
Code
# Gradient attribution using the zero baseline, as originally recommended
# One score per (layer, neuron) and per (layer, head).
mlp_ig_zero_results = torch.zeros(model.cfg.n_layers, model.cfg.d_mlp)
attn_ig_zero_results = torch.zeros(model.cfg.n_layers, model.cfg.n_heads)

# Calculate integrated gradients for each layer
for layer in range(model.cfg.n_layers):
    # Gradient attribution on heads: attribute to hook_result (per-head output)
    # while overriding hook_z (pre-projection head output) as the swapped layer.
    hook_name = get_act_name("result", layer)
    target_layer = model.hook_dict[hook_name]
    prev_layer_hook = get_act_name("z", layer)
    prev_layer = model.hook_dict[prev_layer_hook]

    # NOTE(review): below, `inputs` receives the *zeroed* activations and
    # `baselines` the *clean* ones — swapped relative to the usual IG
    # convention (input = actual run, baseline = reference), which flips the
    # sign of the attributions. Confirm this is intended.
    layer_clean_input = clean_cache[prev_layer_hook]
    layer_corrupt_input = torch.zeros_like(corrupted_cache[prev_layer_hook])

    attributions = compute_layer_to_output_attributions(clean_input, layer_corrupt_input, layer_clean_input, target_layer, prev_layer) # shape [1, seq_len, n_heads, d_model]
    # Calculate attribution score based on mean over each embedding, for each token
    print(attributions.shape)
    per_token_score = attributions.mean(dim=3)  # average over d_model -> [1, seq_len, n_heads]
    score = per_token_score.mean(dim=1)         # average over tokens  -> [1, n_heads]
    attn_ig_zero_results[layer] = score

    # Gradient attribution on MLP neurons (hook_post = post-nonlinearity activations)
    hook_name = get_act_name("post", layer)
    target_layer = model.hook_dict[hook_name]
    prev_layer_hook = get_act_name("mlp_in", layer)
    prev_layer = model.hook_dict[prev_layer_hook]

    layer_clean_input = clean_cache[prev_layer_hook]
    layer_corrupt_input = torch.zeros_like(corrupted_cache[prev_layer_hook])
    
    attributions = compute_layer_to_output_attributions(clean_input, layer_corrupt_input, layer_clean_input, target_layer, prev_layer) # shape [1, seq_len, d_mlp]
    print(attributions.shape)
    score = attributions.mean(dim=1)  # average over tokens -> [1, d_mlp]
    mlp_ig_zero_results[layer] = score

# Persist so later sessions can reload instead of recomputing
torch.save(mlp_ig_zero_results, "mlp_ig_zero_results.pt")
torch.save(attn_ig_zero_results, "attn_ig_zero_results.pt")

Error (delta) for blocks.0.attn.hook_result attribution: 1.598672866821289
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.0.mlp.hook_post attribution: 6.3065714836120605
torch.Size([1, 17, 3072])

Error (delta) for blocks.1.attn.hook_result attribution: 0.07180098444223404
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.1.mlp.hook_post attribution: 0.6716213226318359
torch.Size([1, 17, 3072])

Error (delta) for blocks.2.attn.hook_result attribution: 0.22619788348674774
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.2.mlp.hook_post attribution: 1.6502611637115479
torch.Size([1, 17, 3072])

Error (delta) for blocks.3.attn.hook_result attribution: 0.7059890031814575
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.3.mlp.hook_post attribution: 3.412715435028076
torch.Size([1, 17, 3072])

Error (delta) for blocks.4.attn.hook_result attribution: 59.02961730957031
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.4.mlp.hook_post attribution: 0.9958128929138184
torch.Size([1, 17, 3072])

Error (delta) for blocks.5.attn.hook_result attribution: 1.1134357452392578
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.5.mlp.hook_post attribution: 3.137509822845459
torch.Size([1, 17, 3072])

Error (delta) for blocks.6.attn.hook_result attribution: 0.8119007349014282
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.6.mlp.hook_post attribution: 2.4578588008880615
torch.Size([1, 17, 3072])

Error (delta) for blocks.7.attn.hook_result attribution: 1.150354027748108
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.7.mlp.hook_post attribution: 0.33406609296798706
torch.Size([1, 17, 3072])

Error (delta) for blocks.8.attn.hook_result attribution: 1.5636792182922363
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.8.mlp.hook_post attribution: -0.8596433401107788
torch.Size([1, 17, 3072])

Error (delta) for blocks.9.attn.hook_result attribution: -0.9515517354011536
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.9.mlp.hook_post attribution: 0.7980309128761292
torch.Size([1, 17, 3072])

Error (delta) for blocks.10.attn.hook_result attribution: -0.5746663808822632
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.10.mlp.hook_post attribution: 0.07278364896774292
torch.Size([1, 17, 3072])

Error (delta) for blocks.11.attn.hook_result attribution: -0.17211738228797913
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.11.mlp.hook_post attribution: -0.7840024828910828
torch.Size([1, 17, 3072])
Code
# Symmetric colour limits so the diverging colormap is centred at zero
bound = max(torch.max(mlp_ig_zero_results), abs(torch.min(mlp_ig_zero_results)))

plt.figure(figsize=(75, 10))
plt.imshow(mlp_ig_zero_results.detach(), cmap='RdBu', vmin=-bound, vmax=bound, aspect="auto")
plt.title("MLP Neuron Gradient Attribution (Integrated Gradients)")
plt.xticks(np.arange(0, model.cfg.d_mlp, 250))
plt.xlabel("Neuron Index")
plt.yticks(list(range(model.cfg.n_layers)))
plt.ylabel("Layer")
plt.colorbar()
plt.show()

Code
# Symmetric colour limits so the diverging colormap is centred at zero
bound = max(torch.max(attn_ig_zero_results), abs(torch.min(attn_ig_zero_results)))

plt.figure(figsize=(10, 5))
plt.imshow(attn_ig_zero_results.detach(), cmap='RdBu', vmin=-bound, vmax=bound)
plt.title("Attention Head Gradient Attribution (Integrated Gradients)")

plt.xlabel("Head Index")
plt.xticks(list(range(model.cfg.n_heads)))

plt.ylabel("Layer")
plt.yticks(list(range(model.cfg.n_layers)))

plt.colorbar()
plt.show()

Activation Patching

Code
def patch_neuron_hook(activations: torch.Tensor, hook: HookPoint, cache: ActivationCache, neuron_idx: int):
    """Overwrite one MLP neuron's activations (all batch/seq positions) with the cached run's values."""
    activations[:, :, neuron_idx] = cache[hook.name][:, :, neuron_idx]
    return activations

def patch_attn_hook(activations: torch.Tensor, hook: HookPoint, cache: ActivationCache, head_idx: int):
    """Overwrite one attention head's full result vector with the cached run's values."""
    activations[:, :, head_idx, :] = cache[hook.name][:, :, head_idx, :]
    return activations

# Denominator for normalising patched logit differences (clean minus corrupted)
baseline_diff = (clean_logit_diff - corrupted_logit_diff).item()
Code
# Reload previously-saved patching results (the next cell also attempts this
# load before recomputing, so this cell is optional)
mlp_patch_results = torch.load("mlp_patch_results.pt")
attn_patch_results = torch.load("attn_patch_results.pt")
Code
class StopExecution(Exception):
    """Raised to halt the notebook cell early; IPython suppresses its traceback."""
    def _render_traceback_(self):
        # IPython calls this to format the traceback; an empty list hides it.
        return []

# Check if we have run activation patching already (expensive).
# EAFP: try to load cached results; only the load lives in the try body.
try:
    mlp_patch_results = torch.load("mlp_patch_results.pt")
    attn_patch_results = torch.load("attn_patch_results.pt")
except FileNotFoundError:
    mlp_patch_results = torch.zeros(model.cfg.n_layers, model.cfg.d_mlp)
    attn_patch_results = torch.zeros(model.cfg.n_layers, model.cfg.n_heads)

    for layer in range(model.cfg.n_layers):
        # Hook names depend only on the layer — hoisted out of the inner loops
        attn_hook_name = get_act_name("result", layer)
        mlp_hook_name = get_act_name("post", layer)

        # Activation patching on heads: patch corrupted activations into the clean run
        for head in range(model.cfg.n_heads):
            # head=head binds the loop variable now (avoids late-binding surprises)
            temp_hook = lambda act, hook, head=head: patch_attn_hook(act, hook, corrupted_cache, head)

            with model.hooks(fwd_hooks=[(attn_hook_name, temp_hook)]):
                patched_logits = model(clean_input)

            patched_logit_diff = logits_to_logit_diff(patched_logits).detach()
            # Normalise result by clean and corrupted logit difference
            attn_patch_results[layer, head] = (patched_logit_diff - clean_logit_diff) / baseline_diff

        # Activation patching on MLP neurons
        for neuron in range(model.cfg.d_mlp):
            temp_hook = lambda act, hook, neuron=neuron: patch_neuron_hook(act, hook, corrupted_cache, neuron)

            with model.hooks(fwd_hooks=[(mlp_hook_name, temp_hook)]):
                patched_logits = model(clean_input)

            patched_logit_diff = logits_to_logit_diff(patched_logits).detach()
            # Normalise result by clean and corrupted logit difference
            mlp_patch_results[layer, neuron] = (patched_logit_diff - clean_logit_diff) / baseline_diff

    torch.save(mlp_patch_results, "mlp_patch_results.pt")
    torch.save(attn_patch_results, "attn_patch_results.pt")
else:
    # Cached results loaded successfully — stop the cell to skip recomputation
    raise StopExecution
Code
# Symmetric colour limits so the diverging colormap is centred at zero
bound = max(torch.max(mlp_patch_results), abs(torch.min(mlp_patch_results)))

plt.figure(figsize=(75, 10))
plt.imshow(mlp_patch_results.detach(), cmap='RdBu', vmin=-bound, vmax=bound, aspect="auto")
# Title fix: these scores come from activation patching, not gradient attribution
plt.title("MLP Neuron Attribution (Activation Patching)")
plt.xticks(np.arange(0, model.cfg.d_mlp, 250))
plt.xlabel("Neuron Index")
plt.yticks(list(range(model.cfg.n_layers)))
plt.ylabel("Layer")
plt.colorbar()
plt.show()

Code
# Symmetric colour limits so the diverging colormap is centred at zero
bound = max(torch.max(attn_patch_results), abs(torch.min(attn_patch_results)))

plt.figure(figsize=(10, 5))
plt.imshow(attn_patch_results.detach(), cmap='RdBu', vmin=-bound, vmax=bound)
# Title fix: these scores come from activation patching, not gradient attribution
plt.title("Attention Head Attribution (Activation Patching)")

plt.xlabel("Head Index")
plt.xticks(list(range(model.cfg.n_heads)))

plt.ylabel("Layer")
plt.yticks(list(range(model.cfg.n_layers)))

plt.colorbar()
plt.show()

Analysis of initial attribution methods

Code
# Plot the attribution scores against each other. Perfect agreement: y = x.

x = mlp_ig_zero_results.flatten().numpy()
y = mlp_patch_results.flatten().numpy()

sns.regplot(x=x, y=y)
plt.xlabel("Integrated Gradients MLP Attribution Scores")
plt.ylabel("Activation Patching MLP Attribution Scores")
plt.show()

print(f"Correlation coefficient between IG and AP attributions for neurons: {np.corrcoef(x, y)[0, 1]}")

x = attn_ig_zero_results.flatten().numpy()
y = attn_patch_results.flatten().numpy()

sns.regplot(x=x, y=y)
plt.xlabel("Integrated Gradients Attention Attribution Scores")
# Consistency fix: this axis previously read "Causal Tracing" while the MLP
# plot above and the prints use "Activation Patching" for the same scores
plt.ylabel("Activation Patching Attention Attribution Scores")
plt.show()

print(f"Correlation coefficient between IG and AP attributions for attention: {np.corrcoef(x, y)[0, 1]}")

Correlation coefficient between IG and AP attributions for neurons: 0.2350942982779252

Correlation coefficient between IG and AP attributions for attention: 0.00930087754303134
Code
def get_top_k_by_abs(data, k):
    """Return ((layer, position) index pairs, signed values) for the k largest-magnitude entries.

    Assumes `data` is 2-D with model.cfg.d_mlp columns (flat indices are
    decoded with that width).
    """
    flat = data.flatten()
    _, top_indices = torch.topk(flat.abs(), k)
    top_k_values = torch.gather(flat, 0, top_indices)
    formatted_indices = []
    for flat_idx in top_indices:
        # Decode the flattened index back into (layer, neuron) coordinates
        formatted_indices.append([flat_idx // model.cfg.d_mlp, flat_idx % model.cfg.d_mlp])
    return torch.tensor(formatted_indices), top_k_values

def get_attributions_above_threshold(data, percentile):
    """Zero out entries at or below a threshold interpolated `percentile` of the way from min to max.

    Returns (indices of surviving entries, the masked tensor). Note the
    threshold is min-max interpolation, not a statistical percentile.
    """
    lo, hi = torch.min(data), torch.max(data)
    cutoff = lo + percentile * (hi - lo)
    masked_data = torch.where(data > cutoff, data, 0)
    return torch.nonzero(masked_data), masked_data

# Top-30 neurons by |attribution| under each method
top_mlp_ig_zero_indices, top_mlp_ig_zero_results = get_top_k_by_abs(mlp_ig_zero_results, 30)
top_mlp_patch_indices, top_mlp_patch_results = get_top_k_by_abs(mlp_patch_results, 30)

# Hashable (layer, neuron) tuples for set operations
top_mlp_ig_zero_sets = set([tuple(t.tolist()) for t in top_mlp_ig_zero_indices])
top_mlp_patch_sets = set([tuple(t.tolist()) for t in top_mlp_patch_indices])

# Jaccard similarity = |intersection| / |union| of the two top-30 sets
intersection = top_mlp_ig_zero_sets.intersection(top_mlp_patch_sets)
union = top_mlp_ig_zero_sets.union(top_mlp_patch_sets)
jaccard = len(intersection) / len(union)

print(f"Jaccard score for MLP neurons: {jaccard}")
Jaccard score for MLP neurons: 0.1111111111111111
Code
from sklearn.preprocessing import MaxAbsScaler

mlp_ig_zero_results_1d = mlp_ig_zero_results.flatten().numpy()
mlp_patch_results_1d = mlp_patch_results.flatten().numpy()

# Mean difference (Bland-Altman style) plot with scaled data

# MaxAbsScaler maps each series into [-1, 1] by dividing by its max |value|,
# making the two methods' scores comparable
scaled_mlp_ig_results_1d = MaxAbsScaler().fit_transform(mlp_ig_zero_results_1d.reshape(-1, 1))
scaled_mlp_patch_results_1d = MaxAbsScaler().fit_transform(mlp_patch_results_1d.reshape(-1, 1))

mean = np.mean([scaled_mlp_ig_results_1d, scaled_mlp_patch_results_1d], axis=0)
diff = scaled_mlp_patch_results_1d - scaled_mlp_ig_results_1d
md = np.mean(diff) # Mean of the difference
sd = np.std(diff, axis=0) # Standard deviation of the difference

plt.figure(figsize=(10, 6))
sns.regplot(x=mean, y=diff, fit_reg=True, scatter=True)
# Dashed lines mark the mean difference and the 95% limits of agreement
plt.axhline(md, color='gray', linestyle='--', label="Mean difference")
plt.axhline(md + 1.96*sd, color='pink', linestyle='--', label="1.96 SD of difference")
plt.axhline(md - 1.96*sd, color='lightblue', linestyle='--', label="-1.96 SD of difference")
plt.xlabel("Mean of attribution scores per neuron")
plt.ylabel("Difference (activation patching - integrated gradients) per neuron")
plt.title("Mean-difference plot of scaled attribution scores from integrated gradients and activation patching")
plt.legend()
plt.show()

Code
from sklearn.preprocessing import MaxAbsScaler

# NOTE(review): on these 2-D (layer x head) arrays, MaxAbsScaler scales each
# *column* (head index) independently rather than the whole matrix — confirm
# per-head scaling is intended.
scaled_attn_ig_zero_results = MaxAbsScaler().fit_transform(attn_ig_zero_results)
scaled_attn_patch_results = MaxAbsScaler().fit_transform(attn_patch_results)

# Signed difference, and difference of magnitudes (ignores sign disagreements)
diff_attn_results = scaled_attn_ig_zero_results - scaled_attn_patch_results
diff_attn_results_abs = np.abs(scaled_attn_ig_zero_results) - np.abs(scaled_attn_patch_results)

plt.figure(figsize=(10,10))
plt.subplot(1, 2, 1)
plt.imshow(diff_attn_results, cmap="RdBu")
plt.title("Difference in attributions for attention heads")

plt.xlabel("Head Index")
plt.xticks(list(range(model.cfg.n_heads)))

plt.ylabel("Layer")
plt.yticks(list(range(model.cfg.n_layers)))

plt.colorbar(orientation="horizontal")

plt.subplot(1, 2, 2)
plt.imshow(diff_attn_results_abs, cmap="RdBu")
plt.title("Difference in (absolute) attributions for attention heads")

plt.xlabel("Head Index")
plt.xticks(list(range(model.cfg.n_heads)))

plt.ylabel("Layer")
plt.yticks(list(range(model.cfg.n_layers)))

plt.colorbar(orientation="horizontal")
plt.tight_layout()
plt.show()

Comparable baselines

Hypothesis: One possible reason for the discrepancy between patching and IG is that the range of activations tested may be from different distributions.

Both attribution methods rely on counterfactual reasoning. IG integrates gradients along a path from a chosen baseline (a neutral reference input) to the given input, whereas causal tracing (activation patching) computes the logit difference between two counterfactual inputs. If the counterfactuals used are different, then this could cause a discrepancy.

To evaluate this hypothesis, we compute IG and AP on GPT2-Small with the same counterfactual inputs.

Code
# Reload previously-saved IG (corrupted-baseline) results.
# NOTE(review): the next cell unconditionally recomputes and overwrites these —
# run only one of the two cells per session.
mlp_ig_results = torch.load("mlp_ig_results.pt")
attn_ig_results = torch.load("attn_ig_results.pt")
Code
# Gradient attribution for neurons in MLP layers
mlp_ig_results = torch.zeros(model.cfg.n_layers, model.cfg.d_mlp)
# Gradient attribution for attention heads
attn_ig_results = torch.zeros(model.cfg.n_layers, model.cfg.n_heads)

# Calculate integrated gradients for each layer. Unlike the zero-baseline cell,
# both endpoints are real activations: corrupted-run activations are passed as
# `inputs` and clean-run activations as `baselines`, matching the patching
# counterfactual (corrupted activations substituted into a clean run).
for layer in range(model.cfg.n_layers):
    # Gradient attribution on heads
    hook_name = get_act_name("result", layer)
    target_layer = model.hook_dict[hook_name]
    prev_layer_hook = get_act_name("z", layer)
    prev_layer = model.hook_dict[prev_layer_hook]

    layer_clean_input = clean_cache[prev_layer_hook]
    layer_corrupt_input = corrupted_cache[prev_layer_hook]

    attributions = compute_layer_to_output_attributions(clean_input, layer_corrupt_input, layer_clean_input, target_layer, prev_layer) # shape [1, seq_len, n_heads, d_model]
    # Calculate attribution score based on mean over each embedding, for each token
    print(attributions.shape)
    per_token_score = attributions.mean(dim=3)  # average over d_model -> [1, seq_len, n_heads]
    score = per_token_score.mean(dim=1)         # average over tokens  -> [1, n_heads]
    attn_ig_results[layer] = score

    # Gradient attribution on MLP neurons
    hook_name = get_act_name("post", layer)
    target_layer = model.hook_dict[hook_name]
    prev_layer_hook = get_act_name("mlp_in", layer)
    prev_layer = model.hook_dict[prev_layer_hook]

    layer_clean_input = clean_cache[prev_layer_hook]
    layer_corrupt_input = corrupted_cache[prev_layer_hook]
    
    attributions = compute_layer_to_output_attributions(clean_input, layer_corrupt_input, layer_clean_input, target_layer, prev_layer) # shape [1, seq_len, d_mlp]
    print(attributions.shape)
    score = attributions.mean(dim=1)  # average over tokens -> [1, d_mlp]
    mlp_ig_results[layer] = score

# Persist so later sessions can reload instead of recomputing
torch.save(mlp_ig_results, "mlp_ig_results.pt")
torch.save(attn_ig_results, "attn_ig_results.pt")

Error (delta) for blocks.0.attn.hook_result attribution: -0.08010423183441162
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.0.mlp.hook_post attribution: 5.367663860321045
torch.Size([1, 17, 3072])

Error (delta) for blocks.1.attn.hook_result attribution: 0.05616430938243866
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.1.mlp.hook_post attribution: -0.10697655379772186
torch.Size([1, 17, 3072])

Error (delta) for blocks.2.attn.hook_result attribution: -0.012761879712343216
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.2.mlp.hook_post attribution: -0.11945552378892899
torch.Size([1, 17, 3072])

Error (delta) for blocks.3.attn.hook_result attribution: 0.2565889358520508
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.3.mlp.hook_post attribution: 0.1360594630241394
torch.Size([1, 17, 3072])

Error (delta) for blocks.4.attn.hook_result attribution: 0.051070213317871094
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.4.mlp.hook_post attribution: -0.08819704502820969
torch.Size([1, 17, 3072])

Error (delta) for blocks.5.attn.hook_result attribution: 0.3684248626232147
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.5.mlp.hook_post attribution: 0.24765437841415405
torch.Size([1, 17, 3072])

Error (delta) for blocks.6.attn.hook_result attribution: 0.3670154809951782
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.6.mlp.hook_post attribution: -0.040538765490055084
torch.Size([1, 17, 3072])

Error (delta) for blocks.7.attn.hook_result attribution: 1.3272550106048584
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.7.mlp.hook_post attribution: 0.45815184712409973
torch.Size([1, 17, 3072])

Error (delta) for blocks.8.attn.hook_result attribution: 2.3821561336517334
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.8.mlp.hook_post attribution: 0.12053033709526062
torch.Size([1, 17, 3072])

Error (delta) for blocks.9.attn.hook_result attribution: 1.4569836854934692
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.9.mlp.hook_post attribution: 0.6440849304199219
torch.Size([1, 17, 3072])

Error (delta) for blocks.10.attn.hook_result attribution: -1.0445181131362915
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.10.mlp.hook_post attribution: 0.45244646072387695
torch.Size([1, 17, 3072])

Error (delta) for blocks.11.attn.hook_result attribution: -1.987703800201416
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.11.mlp.hook_post attribution: 0.400459885597229
torch.Size([1, 17, 3072])
Code
# Symmetric colour limits so the diverging colormap is centred at zero
bound = max(torch.max(mlp_ig_results), abs(torch.min(mlp_ig_results)))

plt.figure(figsize=(75, 10))
plt.imshow(mlp_ig_results.detach(), cmap='RdBu', vmin=-bound, vmax=bound, aspect="auto")
plt.title("MLP Neuron Gradient Attribution (Integrated Gradients) with Corrupt Baseline")
plt.xticks(np.arange(0, model.cfg.d_mlp, 250))
plt.xlabel("Neuron Index")
plt.yticks(list(range(model.cfg.n_layers)))
plt.ylabel("Layer")
plt.colorbar()
plt.show()

Code
# Symmetric colour limits so the diverging colormap is centred at zero
bound = max(torch.max(attn_ig_results), abs(torch.min(attn_ig_results)))

plt.figure(figsize=(10, 5))
plt.imshow(attn_ig_results.detach(), cmap='RdBu', vmin=-bound, vmax=bound)
plt.title("Attention Head Gradient Attribution (Integrated Gradients) with Corrupt Baseline")

plt.xlabel("Head Index")
plt.xticks(list(range(model.cfg.n_heads)))

plt.ylabel("Layer")
plt.yticks(list(range(model.cfg.n_layers)))

plt.colorbar()
plt.show()

Analysis of comparable baselines

Code
# Plot the attribution scores against each other. Perfect agreement: y = x.

x = mlp_ig_results.flatten().numpy()
y = mlp_patch_results.flatten().numpy()

sns.regplot(x=x, y=y)
plt.xlabel("Integrated Gradients (Corrupt Baseline) MLP Attribution Scores")
plt.ylabel("Activation Patching MLP Attribution Scores")
plt.show()

print(f"Correlation coefficient between IG with corrupted baseline and AP attributions for neurons: {np.corrcoef(x, y)[0, 1]}")

x = attn_ig_results.flatten().numpy()
y = attn_patch_results.flatten().numpy()

sns.regplot(x=x, y=y)
plt.xlabel("Integrated Gradients (Corrupt Baseline) Attention Attribution Scores")
# Consistency fix: this axis previously read "Causal Tracing" while the MLP
# plot above and the prints use "Activation Patching" for the same scores
plt.ylabel("Activation Patching Attention Attribution Scores")
plt.show()

print(f"Correlation coefficient between IG with corrupted baseline and AP attributions for attention: {np.corrcoef(x, y)[0, 1]}")

Correlation coefficient between IG with corrupted baseline and AP attributions for neurons: 0.9852227509307566

Correlation coefficient between IG with corrupted baseline and AP attributions for attention: 0.9547628738711134

The correlation between attribution scores for MLP neurons and attention heads is extremely high! This indicates that, with the same baseline, both methods obtain very similar attribution scores.

Code
def get_top_k_by_abs(data, k):
    """Return ((layer, position) index pairs, signed values) for the k largest-magnitude entries.

    Duplicate of the earlier helper; kept so this cell runs standalone.
    Assumes `data` has model.cfg.d_mlp columns when decoding flat indices.
    """
    flattened = data.flatten()
    _, top_indices = torch.topk(flattened.abs(), k)
    top_k_values = torch.gather(flattened, 0, top_indices)
    # Decode each flat index into (layer, neuron) coordinates
    formatted_indices = [
        [flat_idx // model.cfg.d_mlp, flat_idx % model.cfg.d_mlp] for flat_idx in top_indices
    ]
    return torch.tensor(formatted_indices), top_k_values

def get_attributions_above_threshold(data, percentile):
    """Keep only entries strictly above a min-max interpolated cutoff; zero the rest.

    Duplicate of the earlier helper; kept so this cell runs standalone.
    """
    data_min = torch.min(data)
    data_max = torch.max(data)
    threshold = data_min + percentile * (data_max - data_min)
    masked_data = torch.where(data > threshold, data, 0)
    return torch.nonzero(masked_data), masked_data

# Top-30 neurons by |attribution| under each method
top_mlp_ig_indices, top_mlp_ig_results = get_top_k_by_abs(mlp_ig_results, 30)
top_mlp_patch_indices, top_mlp_patch_results = get_top_k_by_abs(mlp_patch_results, 30)

# Hashable (layer, neuron) tuples for set operations
top_mlp_ig_sets = set([tuple(t.tolist()) for t in top_mlp_ig_indices])
top_mlp_patch_sets = set([tuple(t.tolist()) for t in top_mlp_patch_indices])

# Jaccard similarity = |intersection| / |union| of the two top-30 sets
intersection = top_mlp_ig_sets.intersection(top_mlp_patch_sets)
union = top_mlp_ig_sets.union(top_mlp_patch_sets)
jaccard = len(intersection) / len(union)

print(f"Jaccard score for MLP neurons: {jaccard}")
Jaccard score for MLP neurons: 0.875
Code
from sklearn.preprocessing import MaxAbsScaler

mlp_ig_results_1d = mlp_ig_results.flatten().numpy()
mlp_patch_results_1d = mlp_patch_results.flatten().numpy()

# Mean difference (Bland-Altman style) plot with scaled data

# MaxAbsScaler maps each series into [-1, 1] by dividing by its max |value|
scaled_mlp_ig_results_1d = MaxAbsScaler().fit_transform(mlp_ig_results_1d.reshape(-1, 1))
scaled_mlp_patch_results_1d = MaxAbsScaler().fit_transform(mlp_patch_results_1d.reshape(-1, 1))

mean = np.mean([scaled_mlp_ig_results_1d, scaled_mlp_patch_results_1d], axis=0)
diff = scaled_mlp_patch_results_1d - scaled_mlp_ig_results_1d
md = np.mean(diff) # Mean of the difference
sd = np.std(diff, axis=0) # Standard deviation of the difference

plt.figure(figsize=(10, 6))
sns.regplot(x=mean, y=diff, fit_reg=True, scatter=True)
# Dashed lines mark the mean difference and the 95% limits of agreement
plt.axhline(md, color='gray', linestyle='--', label="Mean difference")
plt.axhline(md + 1.96*sd, color='pink', linestyle='--', label="1.96 SD of difference")
plt.axhline(md - 1.96*sd, color='lightblue', linestyle='--', label="-1.96 SD of difference")
plt.xlabel("Mean of attribution scores per neuron")
plt.ylabel("Difference (activation patching - integrated gradients) per neuron")
plt.title("Mean-difference plot of scaled attribution scores from integrated gradients and activation patching")
plt.legend()
plt.show()

The mean difference plot seems to suggest that there is still some proportional bias. The difference between activation patching scores and integrated gradients scores increases as the attribution score deviates from 0. Integrated gradients seems to estimate more extreme attribution scores than activation patching.

Code
from sklearn.preprocessing import MaxAbsScaler, StandardScaler, MinMaxScaler, RobustScaler

scaled_attn_ig_results = attn_ig_results * 1e5
scaled_attn_patch_results = attn_patch_results

plt.figure(figsize=(10,10))
plt.subplot(2, 2, 1)
plt.imshow(scaled_attn_ig_results, cmap="RdBu", vmin=-0.4, vmax=0.4)
plt.xlabel("Head Index")
plt.xticks(list(range(model.cfg.n_heads)))
plt.ylabel("Layer")
plt.yticks(list(range(model.cfg.n_layers)))
plt.colorbar(orientation="horizontal")

plt.figure(figsize=(10,10))
plt.subplot(2, 2, 1)
plt.imshow(scaled_attn_patch_results, cmap="RdBu", vmin=-0.4, vmax=0.4)
plt.xlabel("Head Index")
plt.xticks(list(range(model.cfg.n_heads)))
plt.ylabel("Layer")
plt.yticks(list(range(model.cfg.n_layers)))
plt.colorbar(orientation="horizontal")

diff_attn_results = scaled_attn_ig_results - scaled_attn_patch_results
diff_attn_results_abs = np.abs(scaled_attn_ig_results) - np.abs(scaled_attn_patch_results)

plt.figure(figsize=(10,10))
plt.subplot(1, 2, 1)
plt.imshow(diff_attn_results, cmap="RdBu", vmin=-0.2, vmax=0.2)
plt.title("Difference in attributions for attention heads")

plt.xlabel("Head Index")
plt.xticks(list(range(model.cfg.n_heads)))

plt.ylabel("Layer")
plt.yticks(list(range(model.cfg.n_layers)))

plt.colorbar(orientation="horizontal")

plt.subplot(1, 2, 2)
plt.imshow(diff_attn_results_abs, cmap="RdBu", vmin=-0.2, vmax=0.2)
plt.title("Difference in (absolute) attributions for attention heads")

plt.xlabel("Head Index")
plt.xticks(list(range(model.cfg.n_heads)))

plt.ylabel("Layer")
plt.yticks(list(range(model.cfg.n_layers)))

plt.colorbar(orientation="horizontal")
plt.tight_layout()
plt.show()

Remaining questions include:

  • Although both methods are aligned when the baselines are the same, this doesn’t mean that they capture the most faithful attribution scores. For instance, if we change the baseline (which is arbitrarily set to some counterfactual value), we could get different components in the circuit. How do we select the best baselines such that faithful circuits are highlighted?

  • There are still some discrepancies in attribution scores, particularly for attention heads. What could be the cause of different attention head attribution scores?

Comparison to IOI circuit

The attention heads highlighted in the original IOI paper seem to correspond with the attention heads highlighted by both methods.

ioi_diagram

Investigation of discrepancies

General-purpose components

Hypothesis: the components which are highlighted more strongly by integrated gradients are important attention heads, which are used generically in both the clean run and corrupted run.

  • They are not detected as strongly by activation patching, which only takes the difference in logits, i.e. highlights components which are needed for the corrupted run, but not the clean run.
Code
import json

class IOIDataset:
    """Wrapper around a JSON file of IOI examples.

    Each record is a dict with 'prompt' and 'answer' keys. Integer indexing
    yields a (prompt, answer) tuple; slicing yields a list of such tuples.
    """

    def __init__(self, src_path: str):
        with open(src_path) as f:
            self.data = json.load(f)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx):
        # Slices return a list of (prompt, answer) pairs, not an IOIDataset
        if isinstance(idx, slice):
            prompts_answers = [(d['prompt'], d['answer']) for d in self.data[idx]]
            return prompts_answers
        return (self.data[idx]['prompt'], self.data[idx]['answer'])

    def to(self, device):
        # Bug fix: self.data is a plain Python list from json.load, which has
        # no .to method — the old body always raised AttributeError. The data
        # holds no tensors, so a device move is a no-op; keep the chainable API.
        return self
Code
# NOTE(review): slicing returns a plain list of (prompt, answer) tuples, not an
# IOIDataset — fine for the iteration below, but the name is slightly misleading.
ioi_dataset = IOIDataset("ioi_dataset.json")[:10_000]

Experiment 1: zero ablation

To test this, we can ablate the components which have statistically significant attribution scores outside the limits of agreement. If they affect the performance, this shows that the components are necessary, even though they are not picked up by activation patching.

Code
def evaluate_ioi_performance(ioi_dataset: "IOIDataset", model: "HookedTransformer") -> float:
    """Measure greedy-decoding accuracy of *model* on the IOI task.

    Args:
        ioi_dataset: iterable of ``(prompt, answer)`` string pairs.
        model: object exposing a HookedTransformer-style ``generate`` API.

    Returns:
        Fraction of prompts for which *answer* appears (as a substring,
        which tolerates attached punctuation such as "Mary.") in the first
        whitespace-separated token generated after the prompt.
        Returns 0.0 for an empty dataset instead of raising
        ZeroDivisionError.
    """
    num_correct = 0
    num_eval = 0
    for prompt, answer in ioi_dataset:
        if num_eval % 50 == 0:
            print(f"Evaluating prompt {num_eval}")
        outputs = model.generate(input=prompt, max_new_tokens=3, do_sample=False, verbose=False)
        # First generated "word" after the prompt; guard against the model
        # generating only whitespace (the original split()[0] would raise
        # IndexError in that case).
        continuation = outputs.removeprefix(prompt).split()
        generated_answer = continuation[0] if continuation else ""
        if answer in generated_answer:
            num_correct += 1
        num_eval += 1
    return num_correct / num_eval if num_eval else 0.0
Code
# Measure baseline performance of model on IOI task
# (greedy decoding over the 10k sliced prompts; prints progress every
# 50 prompts and finally the accuracy fraction).
baseline_performance = evaluate_ioi_performance(ioi_dataset, model)
print(baseline_performance)
Moving model to device:  cpu
Evaluating prompt 0
Evaluating prompt 50
Evaluating prompt 100
Evaluating prompt 150
Evaluating prompt 200
Evaluating prompt 250
Evaluating prompt 300
Evaluating prompt 350
Evaluating prompt 400
Evaluating prompt 450
Evaluating prompt 500
Evaluating prompt 550
Evaluating prompt 600
Evaluating prompt 650
Evaluating prompt 700
Evaluating prompt 750
Evaluating prompt 800
Evaluating prompt 850
Evaluating prompt 900
Evaluating prompt 950
Evaluating prompt 1000
Evaluating prompt 1050
Evaluating prompt 1100
Evaluating prompt 1150
Evaluating prompt 1200
Evaluating prompt 1250
Evaluating prompt 1300
Evaluating prompt 1350
Evaluating prompt 1400
Evaluating prompt 1450
Evaluating prompt 1500
Evaluating prompt 1550
Evaluating prompt 1600
Evaluating prompt 1650
Evaluating prompt 1700
Evaluating prompt 1750
Evaluating prompt 1800
Evaluating prompt 1850
Evaluating prompt 1900
Evaluating prompt 1950
Evaluating prompt 2000
Evaluating prompt 2050
Evaluating prompt 2100
Evaluating prompt 2150
Evaluating prompt 2200
Evaluating prompt 2250
Evaluating prompt 2300
Evaluating prompt 2350
Evaluating prompt 2400
Evaluating prompt 2450
Evaluating prompt 2500
Evaluating prompt 2550
Evaluating prompt 2600
Evaluating prompt 2650
Evaluating prompt 2700
Evaluating prompt 2750
Evaluating prompt 2800
Evaluating prompt 2850
Evaluating prompt 2900
Evaluating prompt 2950
Evaluating prompt 3000
Evaluating prompt 3050
Evaluating prompt 3100
Evaluating prompt 3150
Evaluating prompt 3200
Evaluating prompt 3250
Evaluating prompt 3300
Evaluating prompt 3350
Evaluating prompt 3400
Evaluating prompt 3450
Evaluating prompt 3500
Evaluating prompt 3550
Evaluating prompt 3600
Evaluating prompt 3650
Evaluating prompt 3700
Evaluating prompt 3750
Evaluating prompt 3800
Evaluating prompt 3850
Evaluating prompt 3900
Evaluating prompt 3950
Evaluating prompt 4000
Evaluating prompt 4050
Evaluating prompt 4100
Evaluating prompt 4150
Evaluating prompt 4200
Evaluating prompt 4250
Evaluating prompt 4300
Evaluating prompt 4350
Evaluating prompt 4400
Evaluating prompt 4450
Evaluating prompt 4500
Evaluating prompt 4550
Evaluating prompt 4600
Evaluating prompt 4650
Evaluating prompt 4700
Evaluating prompt 4750
Evaluating prompt 4800
Evaluating prompt 4850
Evaluating prompt 4900
Evaluating prompt 4950
Evaluating prompt 5000
Evaluating prompt 5050
Evaluating prompt 5100
Evaluating prompt 5150
Evaluating prompt 5200
Evaluating prompt 5250
Evaluating prompt 5300
Evaluating prompt 5350
Evaluating prompt 5400
Evaluating prompt 5450
Evaluating prompt 5500
Evaluating prompt 5550
Evaluating prompt 5600
Evaluating prompt 5650
Evaluating prompt 5700
Evaluating prompt 5750
Evaluating prompt 5800
Evaluating prompt 5850
Evaluating prompt 5900
Evaluating prompt 5950
Evaluating prompt 6000
Evaluating prompt 6050
Evaluating prompt 6100
Evaluating prompt 6150
Evaluating prompt 6200
Evaluating prompt 6250
Evaluating prompt 6300
Evaluating prompt 6350
Evaluating prompt 6400
Evaluating prompt 6450
Evaluating prompt 6500
Evaluating prompt 6550
Evaluating prompt 6600
Evaluating prompt 6650
Evaluating prompt 6700
Evaluating prompt 6750
Evaluating prompt 6800
Evaluating prompt 6850
Evaluating prompt 6900
Evaluating prompt 6950
Evaluating prompt 7000
Evaluating prompt 7050
Evaluating prompt 7100
Evaluating prompt 7150
Evaluating prompt 7200
Evaluating prompt 7250
Evaluating prompt 7300
Evaluating prompt 7350
Evaluating prompt 7400
Evaluating prompt 7450
Evaluating prompt 7500
Evaluating prompt 7550
Evaluating prompt 7600
Evaluating prompt 7650
Evaluating prompt 7700
Evaluating prompt 7750
Evaluating prompt 7800
Evaluating prompt 7850
Evaluating prompt 7900
Evaluating prompt 7950
Evaluating prompt 8000
Evaluating prompt 8050
Evaluating prompt 8100
Evaluating prompt 8150
Evaluating prompt 8200
Evaluating prompt 8250
Evaluating prompt 8300
Evaluating prompt 8350
Evaluating prompt 8400
Evaluating prompt 8450
Evaluating prompt 8500
Evaluating prompt 8550
Evaluating prompt 8600
Evaluating prompt 8650
Evaluating prompt 8700
Evaluating prompt 8750
Evaluating prompt 8800
Evaluating prompt 8850
Evaluating prompt 8900
Evaluating prompt 8950
Evaluating prompt 9000
Evaluating prompt 9050
Evaluating prompt 9100
Evaluating prompt 9150
Evaluating prompt 9200
Evaluating prompt 9250
Evaluating prompt 9300
Evaluating prompt 9350
Evaluating prompt 9400
Evaluating prompt 9450
Evaluating prompt 9500
Evaluating prompt 9550
Evaluating prompt 9600
Evaluating prompt 9650
Evaluating prompt 9700
Evaluating prompt 9750
Evaluating prompt 9800
Evaluating prompt 9850
Evaluating prompt 9900
Evaluating prompt 9950
0.7864
Code
# Identify statistically significant outlier components

# Per-head absolute disagreement between activation patching and IG.
diff = np.abs(scaled_attn_patch_results - scaled_attn_ig_results)
diff_std = np.std(diff.numpy())

print(f"Standard deviation of differences: {diff_std}")

# Collect heads whose disagreement falls outside the ~95% band
# (|diff| > 1.96 standard deviations).
attn_outliers = [
    (layer, head_idx)
    for layer in range(model.cfg.n_layers)
    for head_idx in range(model.cfg.n_heads)
    if diff[layer, head_idx] > 1.96 * diff_std
]

print(attn_outliers)
Standard deviation of differences: 0.020343296229839325
[(7, 3), (8, 10), (9, 6), (9, 9), (10, 6), (11, 10)]
Code
# Ablate components: zero ablation

all_performance_scores = []

for layer, head_idx in attn_outliers:
    # Hook point for the per-head attention output ("result") of this layer.
    # NOTE(review): reading/patching "result" presumably requires the model
    # to be configured with use_attn_result -- confirm model.cfg.
    attn_hook = get_act_name("result", layer)

    # Bind head_idx as a default argument: Python closures capture loop
    # variables late, so without this every hook created in the loop would
    # see the *final* head_idx if hooks outlived their iteration. Benign in
    # the original (each hook is used within its own iteration), but fixed
    # defensively to avoid a subtle bug under refactoring.
    def ablate_hook(activations, hook, head_idx=head_idx):
        # Zero this head's output at every position for every prompt.
        activations[:, :, head_idx, :] = 0
        return activations

    with model.hooks(fwd_hooks=[(attn_hook, ablate_hook)]):
        performance = evaluate_ioi_performance(ioi_dataset, model)
        all_performance_scores.append(performance)
        print(f"Performance after ablating attention head {(layer, head_idx)}: {performance}")

# TODO: Mean ablation, random ablation
Evaluating prompt 0
Evaluating prompt 50
Evaluating prompt 100
Evaluating prompt 150
Evaluating prompt 200
Evaluating prompt 250
Evaluating prompt 300
Evaluating prompt 350
Evaluating prompt 400
Evaluating prompt 450
Evaluating prompt 500
Evaluating prompt 550
Evaluating prompt 600
Evaluating prompt 650
Evaluating prompt 700
Evaluating prompt 750
Evaluating prompt 800
Evaluating prompt 850
Evaluating prompt 900
Evaluating prompt 950
Evaluating prompt 1000
Evaluating prompt 1050
Evaluating prompt 1100
Evaluating prompt 1150
Evaluating prompt 1200
Evaluating prompt 1250
Evaluating prompt 1300
Evaluating prompt 1350
Evaluating prompt 1400
Evaluating prompt 1450
Evaluating prompt 1500
Evaluating prompt 1550
Evaluating prompt 1600
Evaluating prompt 1650
Evaluating prompt 1700
Evaluating prompt 1750
Evaluating prompt 1800
Evaluating prompt 1850
Evaluating prompt 1900
Evaluating prompt 1950
Evaluating prompt 2000
Evaluating prompt 2050
Evaluating prompt 2100
Evaluating prompt 2150
Evaluating prompt 2200
Evaluating prompt 2250
Evaluating prompt 2300
Evaluating prompt 2350
Evaluating prompt 2400
Evaluating prompt 2450
Evaluating prompt 2500
Evaluating prompt 2550
Evaluating prompt 2600
Evaluating prompt 2650
Evaluating prompt 2700
Evaluating prompt 2750
Evaluating prompt 2800
Evaluating prompt 2850
Evaluating prompt 2900
Evaluating prompt 2950
Evaluating prompt 3000
Evaluating prompt 3050
Evaluating prompt 3100
Evaluating prompt 3150
Evaluating prompt 3200
Evaluating prompt 3250
Evaluating prompt 3300
Evaluating prompt 3350
Evaluating prompt 3400
Evaluating prompt 3450
Evaluating prompt 3500
Evaluating prompt 3550
Evaluating prompt 3600
Evaluating prompt 3650
Evaluating prompt 3700
Evaluating prompt 3750
Evaluating prompt 3800
Evaluating prompt 3850
Evaluating prompt 3900
Evaluating prompt 3950
Evaluating prompt 4000
Evaluating prompt 4050
Evaluating prompt 4100
Evaluating prompt 4150
Evaluating prompt 4200
Evaluating prompt 4250
Evaluating prompt 4300
Evaluating prompt 4350
Evaluating prompt 4400
Evaluating prompt 4450
Evaluating prompt 4500
Evaluating prompt 4550
Evaluating prompt 4600
Evaluating prompt 4650
Evaluating prompt 4700
Evaluating prompt 4750
Evaluating prompt 4800
Evaluating prompt 4850
Evaluating prompt 4900
Evaluating prompt 4950
Evaluating prompt 5000
Evaluating prompt 5050
Evaluating prompt 5100
Evaluating prompt 5150
Evaluating prompt 5200
Evaluating prompt 5250
Evaluating prompt 5300
Evaluating prompt 5350
Evaluating prompt 5400
Evaluating prompt 5450
Evaluating prompt 5500
Evaluating prompt 5550
Evaluating prompt 5600
Evaluating prompt 5650
Evaluating prompt 5700
Evaluating prompt 5750
Evaluating prompt 5800
Evaluating prompt 5850
Evaluating prompt 5900
Evaluating prompt 5950
Evaluating prompt 6000
Evaluating prompt 6050
Evaluating prompt 6100
Evaluating prompt 6150
Evaluating prompt 6200
Evaluating prompt 6250
Evaluating prompt 6300
Evaluating prompt 6350
Evaluating prompt 6400
Evaluating prompt 6450
Evaluating prompt 6500
Evaluating prompt 6550
Evaluating prompt 6600
Evaluating prompt 6650
Evaluating prompt 6700
Evaluating prompt 6750
Evaluating prompt 6800
Evaluating prompt 6850
Evaluating prompt 6900
Evaluating prompt 6950
Evaluating prompt 7000
Evaluating prompt 7050
Evaluating prompt 7100
Evaluating prompt 7150
Evaluating prompt 7200
Evaluating prompt 7250
Evaluating prompt 7300
Evaluating prompt 7350
Evaluating prompt 7400
Evaluating prompt 7450
Evaluating prompt 7500
Evaluating prompt 7550
Evaluating prompt 7600
Evaluating prompt 7650
Evaluating prompt 7700
Evaluating prompt 7750
Evaluating prompt 7800
Evaluating prompt 7850
Evaluating prompt 7900
Evaluating prompt 7950
Evaluating prompt 8000
Evaluating prompt 8050
Evaluating prompt 8100
Evaluating prompt 8150
Evaluating prompt 8200
Evaluating prompt 8250
Evaluating prompt 8300
Evaluating prompt 8350
Evaluating prompt 8400
Evaluating prompt 8450
Evaluating prompt 8500
Evaluating prompt 8550
Evaluating prompt 8600
Evaluating prompt 8650
Evaluating prompt 8700
Evaluating prompt 8750
Evaluating prompt 8800
Evaluating prompt 8850
Evaluating prompt 8900
Evaluating prompt 8950
Evaluating prompt 9000
Evaluating prompt 9050
Evaluating prompt 9100
Evaluating prompt 9150
Evaluating prompt 9200
Evaluating prompt 9250
Evaluating prompt 9300
Evaluating prompt 9350
Evaluating prompt 9400
Evaluating prompt 9450
Evaluating prompt 9500
Evaluating prompt 9550
Evaluating prompt 9600
Evaluating prompt 9650
Evaluating prompt 9700
Evaluating prompt 9750
Evaluating prompt 9800
Evaluating prompt 9850
Evaluating prompt 9900
Evaluating prompt 9950
Performance after ablating attention head (7, 3): 0.6981
Evaluating prompt 0
Evaluating prompt 50
Evaluating prompt 100
Evaluating prompt 150
Evaluating prompt 200
Evaluating prompt 250
Evaluating prompt 300
Evaluating prompt 350
Evaluating prompt 400
Evaluating prompt 450
Evaluating prompt 500
Evaluating prompt 550
Evaluating prompt 600
Evaluating prompt 650
Evaluating prompt 700
Evaluating prompt 750
Evaluating prompt 800
Evaluating prompt 850
Evaluating prompt 900
Evaluating prompt 950
Evaluating prompt 1000
Evaluating prompt 1050
Evaluating prompt 1100
Evaluating prompt 1150
Evaluating prompt 1200
Evaluating prompt 1250
Evaluating prompt 1300
Evaluating prompt 1350
Evaluating prompt 1400
Evaluating prompt 1450
Evaluating prompt 1500
Evaluating prompt 1550
Evaluating prompt 1600
Evaluating prompt 1650
Evaluating prompt 1700
Evaluating prompt 1750
Evaluating prompt 1800
Evaluating prompt 1850
Evaluating prompt 1900
Evaluating prompt 1950
Evaluating prompt 2000
Evaluating prompt 2050
Evaluating prompt 2100
Evaluating prompt 2150
Evaluating prompt 2200
Evaluating prompt 2250
Evaluating prompt 2300
Evaluating prompt 2350
Evaluating prompt 2400
Evaluating prompt 2450
Evaluating prompt 2500
Evaluating prompt 2550
Evaluating prompt 2600
Evaluating prompt 2650
Evaluating prompt 2700
Evaluating prompt 2750
Evaluating prompt 2800
Evaluating prompt 2850
Evaluating prompt 2900
Evaluating prompt 2950
Evaluating prompt 3000
Evaluating prompt 3050
Evaluating prompt 3100
Evaluating prompt 3150
Evaluating prompt 3200
Evaluating prompt 3250
Evaluating prompt 3300
Evaluating prompt 3350
Evaluating prompt 3400
Evaluating prompt 3450
Evaluating prompt 3500
Evaluating prompt 3550
Evaluating prompt 3600
Evaluating prompt 3650
Evaluating prompt 3700
Evaluating prompt 3750
Evaluating prompt 3800
Evaluating prompt 3850
Evaluating prompt 3900
Evaluating prompt 3950
Evaluating prompt 4000
Evaluating prompt 4050
Evaluating prompt 4100
Evaluating prompt 4150
Evaluating prompt 4200
Evaluating prompt 4250
Evaluating prompt 4300
Evaluating prompt 4350
Evaluating prompt 4400
Evaluating prompt 4450
Evaluating prompt 4500
Evaluating prompt 4550
Evaluating prompt 4600
Evaluating prompt 4650
Evaluating prompt 4700
Evaluating prompt 4750
Evaluating prompt 4800
Evaluating prompt 4850
Evaluating prompt 4900
Evaluating prompt 4950
Evaluating prompt 5000
Evaluating prompt 5050
Evaluating prompt 5100
Evaluating prompt 5150
Evaluating prompt 5200
Evaluating prompt 5250
Evaluating prompt 5300
Evaluating prompt 5350
Evaluating prompt 5400
Evaluating prompt 5450
Evaluating prompt 5500
Evaluating prompt 5550
Evaluating prompt 5600
Evaluating prompt 5650
Evaluating prompt 5700
Evaluating prompt 5750
Evaluating prompt 5800
Evaluating prompt 5850
Evaluating prompt 5900
Evaluating prompt 5950
Evaluating prompt 6000
Evaluating prompt 6050
Evaluating prompt 6100
Evaluating prompt 6150
Evaluating prompt 6200
Evaluating prompt 6250
Evaluating prompt 6300
Evaluating prompt 6350
Evaluating prompt 6400
Evaluating prompt 6450
Evaluating prompt 6500
Evaluating prompt 6550
Evaluating prompt 6600
Evaluating prompt 6650
Evaluating prompt 6700
Evaluating prompt 6750
Evaluating prompt 6800
Evaluating prompt 6850
Evaluating prompt 6900
Evaluating prompt 6950
Evaluating prompt 7000
Evaluating prompt 7050
Evaluating prompt 7100
Evaluating prompt 7150
Evaluating prompt 7200
Evaluating prompt 7250
Evaluating prompt 7300
Evaluating prompt 7350
Evaluating prompt 7400
Evaluating prompt 7450
Evaluating prompt 7500
Evaluating prompt 7550
Evaluating prompt 7600
Evaluating prompt 7650
Evaluating prompt 7700
Evaluating prompt 7750
Evaluating prompt 7800
Evaluating prompt 7850
Evaluating prompt 7900
Evaluating prompt 7950
Evaluating prompt 8000
Evaluating prompt 8050
Evaluating prompt 8100
Evaluating prompt 8150
Evaluating prompt 8200
Evaluating prompt 8250
Evaluating prompt 8300
Evaluating prompt 8350
Evaluating prompt 8400
Evaluating prompt 8450
Evaluating prompt 8500
Evaluating prompt 8550
Evaluating prompt 8600
Evaluating prompt 8650
Evaluating prompt 8700
Evaluating prompt 8750
Evaluating prompt 8800
Evaluating prompt 8850
Evaluating prompt 8900
Evaluating prompt 8950
Evaluating prompt 9000
Evaluating prompt 9050
Evaluating prompt 9100
Evaluating prompt 9150
Evaluating prompt 9200
Evaluating prompt 9250
Evaluating prompt 9300
Evaluating prompt 9350
Evaluating prompt 9400
Evaluating prompt 9450
Evaluating prompt 9500
Evaluating prompt 9550
Evaluating prompt 9600
Evaluating prompt 9650
Evaluating prompt 9700
Evaluating prompt 9750
Evaluating prompt 9800
Evaluating prompt 9850
Evaluating prompt 9900
Evaluating prompt 9950
Performance after ablating attention head (8, 10): 0.7877
Evaluating prompt 0
Evaluating prompt 50
Evaluating prompt 100
Evaluating prompt 150
Evaluating prompt 200
Evaluating prompt 250
Evaluating prompt 300
Evaluating prompt 350
Evaluating prompt 400
Evaluating prompt 450
Evaluating prompt 500
Evaluating prompt 550
Evaluating prompt 600
Evaluating prompt 650
Evaluating prompt 700
Evaluating prompt 750
Evaluating prompt 800
Evaluating prompt 850
Evaluating prompt 900
Evaluating prompt 950
Evaluating prompt 1000
Evaluating prompt 1050
Evaluating prompt 1100
Evaluating prompt 1150
Evaluating prompt 1200
Evaluating prompt 1250
Evaluating prompt 1300
Evaluating prompt 1350
Evaluating prompt 1400
Evaluating prompt 1450
Evaluating prompt 1500
Evaluating prompt 1550
Evaluating prompt 1600
Evaluating prompt 1650
Evaluating prompt 1700
Evaluating prompt 1750
Evaluating prompt 1800
Evaluating prompt 1850
Evaluating prompt 1900
Evaluating prompt 1950
Evaluating prompt 2000
Evaluating prompt 2050
Evaluating prompt 2100
Evaluating prompt 2150
Evaluating prompt 2200
Evaluating prompt 2250
Evaluating prompt 2300
Evaluating prompt 2350
Evaluating prompt 2400
Evaluating prompt 2450
Evaluating prompt 2500
Evaluating prompt 2550
Evaluating prompt 2600
Evaluating prompt 2650
Evaluating prompt 2700
Evaluating prompt 2750
Evaluating prompt 2800
Evaluating prompt 2850
Evaluating prompt 2900
Evaluating prompt 2950
Evaluating prompt 3000
Evaluating prompt 3050
Evaluating prompt 3100
Evaluating prompt 3150
Evaluating prompt 3200
Evaluating prompt 3250
Evaluating prompt 3300
Evaluating prompt 3350
Evaluating prompt 3400
Evaluating prompt 3450
Evaluating prompt 3500
Evaluating prompt 3550
Evaluating prompt 3600
Evaluating prompt 3650
Evaluating prompt 3700
Evaluating prompt 3750
Evaluating prompt 3800
Evaluating prompt 3850
Evaluating prompt 3900
Evaluating prompt 3950
Evaluating prompt 4000
Evaluating prompt 4050
Evaluating prompt 4100
Evaluating prompt 4150
Evaluating prompt 4200
Evaluating prompt 4250
Evaluating prompt 4300
Evaluating prompt 4350
Evaluating prompt 4400
Evaluating prompt 4450
Evaluating prompt 4500
Evaluating prompt 4550
Evaluating prompt 4600
Evaluating prompt 4650
Evaluating prompt 4700
Evaluating prompt 4750
Evaluating prompt 4800
Evaluating prompt 4850
Evaluating prompt 4900
Evaluating prompt 4950
Evaluating prompt 5000
Evaluating prompt 5050
Evaluating prompt 5100
Evaluating prompt 5150
Evaluating prompt 5200
Evaluating prompt 5250
Evaluating prompt 5300
Evaluating prompt 5350
Evaluating prompt 5400
Evaluating prompt 5450
Evaluating prompt 5500
Evaluating prompt 5550
Evaluating prompt 5600
Evaluating prompt 5650
Evaluating prompt 5700
Evaluating prompt 5750
Evaluating prompt 5800
Evaluating prompt 5850
Evaluating prompt 5900
Evaluating prompt 5950
Evaluating prompt 6000
Evaluating prompt 6050
Evaluating prompt 6100
Evaluating prompt 6150
Evaluating prompt 6200
Evaluating prompt 6250
Evaluating prompt 6300
Evaluating prompt 6350
Evaluating prompt 6400
Evaluating prompt 6450
Evaluating prompt 6500
Evaluating prompt 6550
Evaluating prompt 6600
Evaluating prompt 6650
Evaluating prompt 6700
Evaluating prompt 6750
Evaluating prompt 6800
Evaluating prompt 6850
Evaluating prompt 6900
Evaluating prompt 6950
Evaluating prompt 7000
Evaluating prompt 7050
Evaluating prompt 7100
Evaluating prompt 7150
Evaluating prompt 7200
Evaluating prompt 7250
Evaluating prompt 7300
Evaluating prompt 7350
Evaluating prompt 7400
Evaluating prompt 7450
Evaluating prompt 7500
Evaluating prompt 7550
Evaluating prompt 7600
Evaluating prompt 7650
Evaluating prompt 7700
Evaluating prompt 7750
Evaluating prompt 7800
Evaluating prompt 7850
Evaluating prompt 7900
Evaluating prompt 7950
Evaluating prompt 8000
Evaluating prompt 8050
Evaluating prompt 8100
Evaluating prompt 8150
Evaluating prompt 8200
Evaluating prompt 8250
Evaluating prompt 8300
Evaluating prompt 8350
Evaluating prompt 8400
Evaluating prompt 8450
Evaluating prompt 8500
Evaluating prompt 8550
Evaluating prompt 8600
Evaluating prompt 8650
Evaluating prompt 8700
Evaluating prompt 8750
Evaluating prompt 8800
Evaluating prompt 8850
Evaluating prompt 8900
Evaluating prompt 8950
Evaluating prompt 9000
Evaluating prompt 9050
Evaluating prompt 9100
Evaluating prompt 9150
Evaluating prompt 9200
Evaluating prompt 9250
Evaluating prompt 9300
Evaluating prompt 9350
Evaluating prompt 9400
Evaluating prompt 9450
Evaluating prompt 9500
Evaluating prompt 9550
Evaluating prompt 9600
Evaluating prompt 9650
Evaluating prompt 9700
Evaluating prompt 9750
Evaluating prompt 9800
Evaluating prompt 9850
Evaluating prompt 9900
Evaluating prompt 9950
Performance after ablating attention head (9, 6): 0.8052
Evaluating prompt 0
Evaluating prompt 50
Evaluating prompt 100
Evaluating prompt 150
Evaluating prompt 200
Evaluating prompt 250
Evaluating prompt 300
Evaluating prompt 350
Evaluating prompt 400
Evaluating prompt 450
Evaluating prompt 500
Evaluating prompt 550
Evaluating prompt 600
Evaluating prompt 650
Evaluating prompt 700
Evaluating prompt 750
Evaluating prompt 800
Evaluating prompt 850
Evaluating prompt 900
Evaluating prompt 950
Evaluating prompt 1000
Evaluating prompt 1050
Evaluating prompt 1100
Evaluating prompt 1150
Evaluating prompt 1200
Evaluating prompt 1250
Evaluating prompt 1300
Evaluating prompt 1350
Evaluating prompt 1400
Evaluating prompt 1450
Evaluating prompt 1500
Evaluating prompt 1550
Evaluating prompt 1600
Evaluating prompt 1650
Evaluating prompt 1700
Evaluating prompt 1750
Evaluating prompt 1800
Evaluating prompt 1850
Evaluating prompt 1900
Evaluating prompt 1950
Evaluating prompt 2000
Evaluating prompt 2050
Evaluating prompt 2100
Evaluating prompt 2150
Evaluating prompt 2200
Evaluating prompt 2250
Evaluating prompt 2300
Evaluating prompt 2350
Evaluating prompt 2400
Evaluating prompt 2450
Evaluating prompt 2500
Evaluating prompt 2550
Evaluating prompt 2600
Evaluating prompt 2650
Evaluating prompt 2700
Evaluating prompt 2750
Evaluating prompt 2800
Evaluating prompt 2850
Evaluating prompt 2900
Evaluating prompt 2950
Evaluating prompt 3000
Evaluating prompt 3050
Evaluating prompt 3100
Evaluating prompt 3150
Evaluating prompt 3200
Evaluating prompt 3250
Evaluating prompt 3300
Evaluating prompt 3350
Evaluating prompt 3400
Evaluating prompt 3450
Evaluating prompt 3500
Evaluating prompt 3550
Evaluating prompt 3600
Evaluating prompt 3650
Evaluating prompt 3700
Evaluating prompt 3750
Evaluating prompt 3800
Evaluating prompt 3850
Evaluating prompt 3900
Evaluating prompt 3950
Evaluating prompt 4000
Evaluating prompt 4050
Evaluating prompt 4100
Evaluating prompt 4150
Evaluating prompt 4200
Evaluating prompt 4250
Evaluating prompt 4300
Evaluating prompt 4350
Evaluating prompt 4400
Evaluating prompt 4450
Evaluating prompt 4500
Evaluating prompt 4550
Evaluating prompt 4600
Evaluating prompt 4650
Evaluating prompt 4700
Evaluating prompt 4750
Evaluating prompt 4800
Evaluating prompt 4850
Evaluating prompt 4900
Evaluating prompt 4950
Evaluating prompt 5000
Evaluating prompt 5050
Evaluating prompt 5100
Evaluating prompt 5150
Evaluating prompt 5200
Evaluating prompt 5250
Evaluating prompt 5300
Evaluating prompt 5350
Evaluating prompt 5400
Evaluating prompt 5450
Evaluating prompt 5500
Evaluating prompt 5550
Evaluating prompt 5600
Evaluating prompt 5650
Evaluating prompt 5700
Evaluating prompt 5750
Evaluating prompt 5800
Evaluating prompt 5850
Evaluating prompt 5900
Evaluating prompt 5950
Evaluating prompt 6000
Evaluating prompt 6050
Evaluating prompt 6100
Evaluating prompt 6150
Evaluating prompt 6200
Evaluating prompt 6250
Evaluating prompt 6300
Evaluating prompt 6350
Evaluating prompt 6400
Evaluating prompt 6450
Evaluating prompt 6500
Evaluating prompt 6550
Evaluating prompt 6600
Evaluating prompt 6650
Evaluating prompt 6700
Evaluating prompt 6750
Evaluating prompt 6800
Evaluating prompt 6850
Evaluating prompt 6900
Evaluating prompt 6950
Evaluating prompt 7000
Evaluating prompt 7050
Evaluating prompt 7100
Evaluating prompt 7150
Evaluating prompt 7200
Evaluating prompt 7250
Evaluating prompt 7300
Evaluating prompt 7350
Evaluating prompt 7400
Evaluating prompt 7450
Evaluating prompt 7500
Evaluating prompt 7550
Evaluating prompt 7600
Evaluating prompt 7650
Evaluating prompt 7700
Evaluating prompt 7750
Evaluating prompt 7800
Evaluating prompt 7850
Evaluating prompt 7900
Evaluating prompt 7950
Evaluating prompt 8000
Evaluating prompt 8050
Evaluating prompt 8100
Evaluating prompt 8150
Evaluating prompt 8200
Evaluating prompt 8250
Evaluating prompt 8300
Evaluating prompt 8350
Evaluating prompt 8400
Evaluating prompt 8450
Evaluating prompt 8500
Evaluating prompt 8550
Evaluating prompt 8600
Evaluating prompt 8650
Evaluating prompt 8700
Evaluating prompt 8750
Evaluating prompt 8800
Evaluating prompt 8850
Evaluating prompt 8900
Evaluating prompt 8950
Evaluating prompt 9000
Evaluating prompt 9050
Evaluating prompt 9100
Evaluating prompt 9150
Evaluating prompt 9200
Evaluating prompt 9250
Evaluating prompt 9300
Evaluating prompt 9350
Evaluating prompt 9400
Evaluating prompt 9450
Evaluating prompt 9500
Evaluating prompt 9550
Evaluating prompt 9600
Evaluating prompt 9650
Evaluating prompt 9700
Evaluating prompt 9750
Evaluating prompt 9800
Evaluating prompt 9850
Evaluating prompt 9900
Evaluating prompt 9950
Performance after ablating attention head (9, 9): 0.8009
Evaluating prompt 0
Evaluating prompt 50
Evaluating prompt 100
Evaluating prompt 150
Evaluating prompt 200
Evaluating prompt 250
Evaluating prompt 300
Evaluating prompt 350
Evaluating prompt 400
Evaluating prompt 450
Evaluating prompt 500
Evaluating prompt 550
Evaluating prompt 600
Evaluating prompt 650
Evaluating prompt 700
Evaluating prompt 750
Evaluating prompt 800
Evaluating prompt 850
Evaluating prompt 900
Evaluating prompt 950
Evaluating prompt 1000
Evaluating prompt 1050
Evaluating prompt 1100
Evaluating prompt 1150
Evaluating prompt 1200
Evaluating prompt 1250
Evaluating prompt 1300
Evaluating prompt 1350
Evaluating prompt 1400
Evaluating prompt 1450
Evaluating prompt 1500
Evaluating prompt 1550
Evaluating prompt 1600
Evaluating prompt 1650
Evaluating prompt 1700
Evaluating prompt 1750
Evaluating prompt 1800
Evaluating prompt 1850
Evaluating prompt 1900
Evaluating prompt 1950
Evaluating prompt 2000
Evaluating prompt 2050
Evaluating prompt 2100
Evaluating prompt 2150
Evaluating prompt 2200
Evaluating prompt 2250
Evaluating prompt 2300
Evaluating prompt 2350
Evaluating prompt 2400
Evaluating prompt 2450
Evaluating prompt 2500
Evaluating prompt 2550
Evaluating prompt 2600
Evaluating prompt 2650
Evaluating prompt 2700
Evaluating prompt 2750
Evaluating prompt 2800
Evaluating prompt 2850
Evaluating prompt 2900
Evaluating prompt 2950
Evaluating prompt 3000
Evaluating prompt 3050
Evaluating prompt 3100
Evaluating prompt 3150
Evaluating prompt 3200
Evaluating prompt 3250
Evaluating prompt 3300
Evaluating prompt 3350
Evaluating prompt 3400
Evaluating prompt 3450
Evaluating prompt 3500
Evaluating prompt 3550
Evaluating prompt 3600
Evaluating prompt 3650
Evaluating prompt 3700
Evaluating prompt 3750
Evaluating prompt 3800
Evaluating prompt 3850
Evaluating prompt 3900
Evaluating prompt 3950
Evaluating prompt 4000
Evaluating prompt 4050
Evaluating prompt 4100
Evaluating prompt 4150
Evaluating prompt 4200
Evaluating prompt 4250
Evaluating prompt 4300
Evaluating prompt 4350
Evaluating prompt 4400
Evaluating prompt 4450
Evaluating prompt 4500
Evaluating prompt 4550
Evaluating prompt 4600
Evaluating prompt 4650
Evaluating prompt 4700
Evaluating prompt 4750
Evaluating prompt 4800
Evaluating prompt 4850
Evaluating prompt 4900
Evaluating prompt 4950
Evaluating prompt 5000
Evaluating prompt 5050
Evaluating prompt 5100
Evaluating prompt 5150
Evaluating prompt 5200
Evaluating prompt 5250
Evaluating prompt 5300
Evaluating prompt 5350
Evaluating prompt 5400
Evaluating prompt 5450
Evaluating prompt 5500
Evaluating prompt 5550
Evaluating prompt 5600
Evaluating prompt 5650
Evaluating prompt 5700
Evaluating prompt 5750
Evaluating prompt 5800
Evaluating prompt 5850
Evaluating prompt 5900
Evaluating prompt 5950
Evaluating prompt 6000
Evaluating prompt 6050
Evaluating prompt 6100
Evaluating prompt 6150
Evaluating prompt 6200
Evaluating prompt 6250
Evaluating prompt 6300
Evaluating prompt 6350
Evaluating prompt 6400
Evaluating prompt 6450
Evaluating prompt 6500
Evaluating prompt 6550
Evaluating prompt 6600
Evaluating prompt 6650
Evaluating prompt 6700
Evaluating prompt 6750
Evaluating prompt 6800
Evaluating prompt 6850
Evaluating prompt 6900
Evaluating prompt 6950
Evaluating prompt 7000
Evaluating prompt 7050
Evaluating prompt 7100
Evaluating prompt 7150
Evaluating prompt 7200
Evaluating prompt 7250
Evaluating prompt 7300
Evaluating prompt 7350
Evaluating prompt 7400
Evaluating prompt 7450
Evaluating prompt 7500
Evaluating prompt 7550
Evaluating prompt 7600
Evaluating prompt 7650
Evaluating prompt 7700
Evaluating prompt 7750
Evaluating prompt 7800
Evaluating prompt 7850
Evaluating prompt 7900
Evaluating prompt 7950
Evaluating prompt 8000
Evaluating prompt 8050
Evaluating prompt 8100
Evaluating prompt 8150
Evaluating prompt 8200
Evaluating prompt 8250
Evaluating prompt 8300
Evaluating prompt 8350
Evaluating prompt 8400
Evaluating prompt 8450
Evaluating prompt 8500
Evaluating prompt 8550
Evaluating prompt 8600
Evaluating prompt 8650
Evaluating prompt 8700
Evaluating prompt 8750
Evaluating prompt 8800
Evaluating prompt 8850
Evaluating prompt 8900
Evaluating prompt 8950
Evaluating prompt 9000
Evaluating prompt 9050
Evaluating prompt 9100
Evaluating prompt 9150
Evaluating prompt 9200
Evaluating prompt 9250
Evaluating prompt 9300
Evaluating prompt 9350
Evaluating prompt 9400
Evaluating prompt 9450
Evaluating prompt 9500
Evaluating prompt 9550
Evaluating prompt 9600
Evaluating prompt 9650
Evaluating prompt 9700
Evaluating prompt 9750
Evaluating prompt 9800
Evaluating prompt 9850
Evaluating prompt 9900
Evaluating prompt 9950
Performance after ablating attention head (10, 6): 0.7245
Evaluating prompt 0
Evaluating prompt 50
Evaluating prompt 100
Evaluating prompt 150
Evaluating prompt 200
Evaluating prompt 250
Evaluating prompt 300
Evaluating prompt 350
Evaluating prompt 400
Evaluating prompt 450
Evaluating prompt 500
Evaluating prompt 550
Evaluating prompt 600
Evaluating prompt 650
Evaluating prompt 700
Evaluating prompt 750
Evaluating prompt 800
Evaluating prompt 850
Evaluating prompt 900
Evaluating prompt 950
Evaluating prompt 1000
Evaluating prompt 1050
Evaluating prompt 1100
Evaluating prompt 1150
Evaluating prompt 1200
Evaluating prompt 1250
Evaluating prompt 1300
Evaluating prompt 1350
Evaluating prompt 1400
Evaluating prompt 1450
Evaluating prompt 1500
Evaluating prompt 1550
Evaluating prompt 1600
Evaluating prompt 1650
Evaluating prompt 1700
Evaluating prompt 1750
Evaluating prompt 1800
Evaluating prompt 1850
Evaluating prompt 1900
Evaluating prompt 1950
Evaluating prompt 2000
Evaluating prompt 2050
Evaluating prompt 2100
Evaluating prompt 2150
Evaluating prompt 2200
Evaluating prompt 2250
Evaluating prompt 2300
Evaluating prompt 2350
Evaluating prompt 2400
Evaluating prompt 2450
Evaluating prompt 2500
Evaluating prompt 2550
Evaluating prompt 2600
Evaluating prompt 2650
Evaluating prompt 2700
Evaluating prompt 2750
Evaluating prompt 2800
Evaluating prompt 2850
Evaluating prompt 2900
Evaluating prompt 2950
Evaluating prompt 3000
Evaluating prompt 3050
Evaluating prompt 3100
Evaluating prompt 3150
Evaluating prompt 3200
Evaluating prompt 3250
Evaluating prompt 3300
Evaluating prompt 3350
Evaluating prompt 3400
Evaluating prompt 3450
Evaluating prompt 3500
Evaluating prompt 3550
Evaluating prompt 3600
Evaluating prompt 3650
Evaluating prompt 3700
Evaluating prompt 3750
Evaluating prompt 3800
Evaluating prompt 3850
Evaluating prompt 3900
Evaluating prompt 3950
Evaluating prompt 4000
Evaluating prompt 4050
Evaluating prompt 4100
Evaluating prompt 4150
Evaluating prompt 4200
Evaluating prompt 4250
Evaluating prompt 4300
Evaluating prompt 4350
Evaluating prompt 4400
Evaluating prompt 4450
Evaluating prompt 4500
Evaluating prompt 4550
Evaluating prompt 4600
Evaluating prompt 4650
Evaluating prompt 4700
Evaluating prompt 4750
Evaluating prompt 4800
Evaluating prompt 4850
Evaluating prompt 4900
Evaluating prompt 4950
Evaluating prompt 5000
Evaluating prompt 5050
Evaluating prompt 5100
Evaluating prompt 5150
Evaluating prompt 5200
Evaluating prompt 5250
Evaluating prompt 5300
Evaluating prompt 5350
Evaluating prompt 5400
Evaluating prompt 5450
Evaluating prompt 5500
Evaluating prompt 5550
Evaluating prompt 5600
Evaluating prompt 5650
Evaluating prompt 5700
Evaluating prompt 5750
Evaluating prompt 5800
Evaluating prompt 5850
Evaluating prompt 5900
Evaluating prompt 5950
Evaluating prompt 6000
Evaluating prompt 6050
Evaluating prompt 6100
Evaluating prompt 6150
Evaluating prompt 6200
Evaluating prompt 6250
Evaluating prompt 6300
Evaluating prompt 6350
Evaluating prompt 6400
Evaluating prompt 6450
Evaluating prompt 6500
Evaluating prompt 6550
Evaluating prompt 6600
Evaluating prompt 6650
Evaluating prompt 6700
Evaluating prompt 6750
Evaluating prompt 6800
Evaluating prompt 6850
Evaluating prompt 6900
Evaluating prompt 6950
Evaluating prompt 7000
Evaluating prompt 7050
Evaluating prompt 7100
Evaluating prompt 7150
Evaluating prompt 7200
Evaluating prompt 7250
Evaluating prompt 7300
Evaluating prompt 7350
Evaluating prompt 7400
Evaluating prompt 7450
Evaluating prompt 7500
Evaluating prompt 7550
Evaluating prompt 7600
Evaluating prompt 7650
Evaluating prompt 7700
Evaluating prompt 7750
Evaluating prompt 7800
Evaluating prompt 7850
Evaluating prompt 7900
Evaluating prompt 7950
Evaluating prompt 8000
Evaluating prompt 8050
Evaluating prompt 8100
Evaluating prompt 8150
Evaluating prompt 8200
Evaluating prompt 8250
Evaluating prompt 8300
Evaluating prompt 8350
Evaluating prompt 8400
Evaluating prompt 8450
Evaluating prompt 8500
Evaluating prompt 8550
Evaluating prompt 8600
Evaluating prompt 8650
Evaluating prompt 8700
Evaluating prompt 8750
Evaluating prompt 8800
Evaluating prompt 8850
Evaluating prompt 8900
Evaluating prompt 8950
Evaluating prompt 9000
Evaluating prompt 9050
Evaluating prompt 9100
Evaluating prompt 9150
Evaluating prompt 9200
Evaluating prompt 9250
Evaluating prompt 9300
Evaluating prompt 9350
Evaluating prompt 9400
Evaluating prompt 9450
Evaluating prompt 9500
Evaluating prompt 9550
Evaluating prompt 9600
Evaluating prompt 9650
Evaluating prompt 9700
Evaluating prompt 9750
Evaluating prompt 9800
Evaluating prompt 9850
Evaluating prompt 9900
Evaluating prompt 9950
Performance after ablating attention head (11, 10): 0.9361
Code
# Report the baseline IOI performance followed by the per-ablation scores.
for report_value in (baseline_performance, all_performance_scores):
    print(report_value)
0.7864
[0.6981, 0.7877, 0.8052, 0.8009, 0.7245, 0.9361]
Code
np.save("all_performance_scores.npy", all_performance_scores)
Code
# Print the IG and AP attribution scores side by side for each outlier head.
for head in attn_outliers:
    layer, idx = head
    print(f"Attention head {head}")
    print(f"IG score: {attn_ig_results[layer, idx]:.5f}, AP score: {attn_patch_results[layer, idx]:.5f}\n")
Attention head (7, 3)
IG score: -0.00000, AP score: -0.10789

Attention head (8, 10)
IG score: -0.00000, AP score: -0.14233

Attention head (9, 6)
IG score: -0.00000, AP score: 0.00889

Attention head (9, 9)
IG score: -0.00000, AP score: -0.11382

Attention head (10, 6)
IG score: -0.00000, AP score: -0.07559

Attention head (11, 10)
IG score: 0.00000, AP score: 0.21538
Code
# Bar chart: baseline performance vs. performance after zero-ablating each outlier head.
labels = ["None"]
labels.extend(str(head) for head in attn_outliers)
heights = [baseline_performance, *all_performance_scores]

plt.title("Model performance after zero ablation of attention head outliers")
plt.xlabel("Ablated attention head position")
plt.ylabel("Model performance on IOI tasks")
plt.bar(labels, heights)
plt.show()

Code
# Correlation between difference in attribution scores and difference in performance.
# Builds four aligned lists over the outlier heads:
#   performance_differences — ablated performance minus baseline
#   ig_outlier_scores / ap_outlier_scores — attribution score of each outlier head
#   score_differences — AP score minus IG score (method disagreement)
performance_differences = []
ig_outlier_scores = []
ap_outlier_scores = []
score_differences = []

# zip pairs each outlier head with its ablation performance (same order,
# same length) — clearer than indexing via range(len(...)).
for (layer, attn_idx), performance in zip(attn_outliers, all_performance_scores):
    performance_differences.append(performance - baseline_performance)
    ap_outlier_scores.append(attn_patch_results[layer, attn_idx])
    ig_outlier_scores.append(attn_ig_results[layer, attn_idx])
    score_differences.append(attn_patch_results[layer, attn_idx] - attn_ig_results[layer, attn_idx])
Code
# Regression of performance change on the AP-minus-IG attribution difference.
sns.regplot(x=score_differences, y=performance_differences)
plt.xlabel("Difference between patching and IG attribution scores")
plt.ylabel("Difference between performance and baseline performance")
plt.show()

corr = np.corrcoef(score_differences, performance_differences)[0, 1]
print(f"Correlation coefficient between attribution score differences and performance score differences: {corr}")

Correlation coefficient between attribution score differences and performance score differences: 0.8298786622776018
Code
# Correlation between attribution scores and performance change

# IG attribution score vs. performance change.
sns.regplot(x=ig_outlier_scores, y=performance_differences)
plt.ylabel("Difference between performance and baseline performance")
plt.xlabel("IG attribution scores")
plt.show()

print(f"Correlation coefficient between IG attribution score and performance score differences: {np.corrcoef(ig_outlier_scores, performance_differences)[0, 1]}")

# Activation-patching attribution score vs. performance change.
sns.regplot(x=ap_outlier_scores, y=performance_differences)
plt.ylabel("Difference between performance and baseline performance")
plt.xlabel("Activation Patching attribution scores")
plt.show()

# Fixed copy-paste bug: this line previously said "IG attribution score"
# although it reports the Activation-Patching correlation.
print(f"Correlation coefficient between AP attribution score and performance score differences: {np.corrcoef(ap_outlier_scores, performance_differences)[0, 1]}")

Correlation coefficient between IG attribution score and performance score differences: 0.7296915879554248

Correlation coefficient between IG attribution score and performance score differences: 0.8298774188113335

Of all the outliers, head (9, 6) is the only one which is strongly highlighted by integrated gradients, but not by activation patching. The other outliers have larger attribution scores assigned by integrated gradients compared to activation patching, but are highlighted by both methods.

  • Ablating head (9, 6) does not have a strong effect on the performance.
    • Conclusion: components which are only identified by integrated gradients may not be important for the specific task.
  • Interestingly, ablating heads identified as moderately important by activation patching (e.g. (8, 10), (9, 9), and (10, 6)) does not have a significant impact on the performance either.
    • Conclusion: neither method identifies the minimal set of important attention heads.
    • Comparison to original IOI paper: under mean ablation, these heads (and 9.6) are highlighted and impact performance more noticeably.
  • There is not really a strong pattern / correlation between components which have higher attribution scores under IG or AP, and impact on performance.

Experiment 2: Mean ablation

Instead of using zero ablation, we use mean ablation to study the effect of a component’s removal on the model’s performance.

Code
import random

# Compute per-head mean activations over a random sample of IOI prompts,
# caching only the attention-result hooks for the outlier layers.
model = model.to("cpu")

attn_outlier_hooks = [get_act_name("result", layer_idx) for layer_idx, _ in attn_outliers]

random_prompts = random.sample(ioi_dataset, 100)
prompts_tokens = model.to_tokens([prompt for prompt, _ in random_prompts])
_, prompt_cache = model.run_with_cache(
    prompts_tokens, names_filter=lambda name: name in attn_outlier_hooks
)

# Averaging jointly over the prompt and position dims is equivalent to the
# two sequential means (uniform weights over full dimensions).
mean_activations = {
    key: torch.mean(acts, dim=(0, 1)) for key, acts in prompt_cache.items()
}
Moving model to device:  cpu
Code
# Ablate components: mean ablation

all_performance_scores_mean_ablation = []

# Mean-ablate each outlier head in turn and measure IOI performance.
for layer, head_idx in attn_outliers:
    hook_name = get_act_name("result", layer)

    # Bind head_idx as a default argument so the closure does not rely on
    # Python's late binding of the loop variable (defensive; behavior is
    # unchanged since each hook is only used within its own iteration).
    def ablate_hook(activations, hook, head=head_idx):
        # Overwrite this head's output at every position with its mean activation.
        activations[:, :, head, :] = mean_activations[hook.name][head]
        return activations

    with model.hooks(fwd_hooks=[(hook_name, ablate_hook)]):
        performance = evaluate_ioi_performance(ioi_dataset, model)
        all_performance_scores_mean_ablation.append(performance)
        print(f"Performance after mean ablating attention head {(layer, head_idx)}: {performance}")
Evaluating prompt 0
Evaluating prompt 50
Evaluating prompt 100
Evaluating prompt 150
Evaluating prompt 200
Evaluating prompt 250
Evaluating prompt 300
Evaluating prompt 350
Evaluating prompt 400
Evaluating prompt 450
Evaluating prompt 500
Evaluating prompt 550
Evaluating prompt 600
Evaluating prompt 650
Evaluating prompt 700
Evaluating prompt 750
Evaluating prompt 800
Evaluating prompt 850
Evaluating prompt 900
Evaluating prompt 950
Evaluating prompt 1000
Evaluating prompt 1050
Evaluating prompt 1100
Evaluating prompt 1150
Evaluating prompt 1200
Evaluating prompt 1250
Evaluating prompt 1300
Evaluating prompt 1350
Evaluating prompt 1400
Evaluating prompt 1450
Evaluating prompt 1500
Evaluating prompt 1550
Evaluating prompt 1600
Evaluating prompt 1650
Evaluating prompt 1700
Evaluating prompt 1750
Evaluating prompt 1800
Evaluating prompt 1850
Evaluating prompt 1900
Evaluating prompt 1950
Evaluating prompt 2000
Evaluating prompt 2050
Evaluating prompt 2100
Evaluating prompt 2150
Evaluating prompt 2200
Evaluating prompt 2250
Evaluating prompt 2300
Evaluating prompt 2350
Evaluating prompt 2400
Evaluating prompt 2450
Evaluating prompt 2500
Evaluating prompt 2550
Evaluating prompt 2600
Evaluating prompt 2650
Evaluating prompt 2700
Evaluating prompt 2750
Evaluating prompt 2800
Evaluating prompt 2850
Evaluating prompt 2900
Evaluating prompt 2950
Evaluating prompt 3000
Evaluating prompt 3050
Evaluating prompt 3100
Evaluating prompt 3150
Evaluating prompt 3200
Evaluating prompt 3250
Evaluating prompt 3300
Evaluating prompt 3350
Evaluating prompt 3400
Evaluating prompt 3450
Evaluating prompt 3500
Evaluating prompt 3550
Evaluating prompt 3600
Evaluating prompt 3650
Evaluating prompt 3700
Evaluating prompt 3750
Evaluating prompt 3800
Evaluating prompt 3850
Evaluating prompt 3900
Evaluating prompt 3950
Evaluating prompt 4000
Evaluating prompt 4050
Evaluating prompt 4100
Evaluating prompt 4150
Evaluating prompt 4200
Evaluating prompt 4250
Evaluating prompt 4300
Evaluating prompt 4350
Evaluating prompt 4400
Evaluating prompt 4450
Evaluating prompt 4500
Evaluating prompt 4550
Evaluating prompt 4600
Evaluating prompt 4650
Evaluating prompt 4700
Evaluating prompt 4750
Evaluating prompt 4800
Evaluating prompt 4850
Evaluating prompt 4900
Evaluating prompt 4950
Evaluating prompt 5000
Evaluating prompt 5050
Evaluating prompt 5100
Evaluating prompt 5150
Evaluating prompt 5200
Evaluating prompt 5250
Evaluating prompt 5300
Evaluating prompt 5350
Evaluating prompt 5400
Evaluating prompt 5450
Evaluating prompt 5500
Evaluating prompt 5550
Evaluating prompt 5600
Evaluating prompt 5650
Evaluating prompt 5700
Evaluating prompt 5750
Evaluating prompt 5800
Evaluating prompt 5850
Evaluating prompt 5900
Evaluating prompt 5950
Evaluating prompt 6000
Evaluating prompt 6050
Evaluating prompt 6100
Evaluating prompt 6150
Evaluating prompt 6200
Evaluating prompt 6250
Evaluating prompt 6300
Evaluating prompt 6350
Evaluating prompt 6400
Evaluating prompt 6450
Evaluating prompt 6500
Evaluating prompt 6550
Evaluating prompt 6600
Evaluating prompt 6650
Evaluating prompt 6700
Evaluating prompt 6750
Evaluating prompt 6800
Evaluating prompt 6850
Evaluating prompt 6900
Evaluating prompt 6950
Evaluating prompt 7000
Evaluating prompt 7050
Evaluating prompt 7100
Evaluating prompt 7150
Evaluating prompt 7200
Evaluating prompt 7250
Evaluating prompt 7300
Evaluating prompt 7350
Evaluating prompt 7400
Evaluating prompt 7450
Evaluating prompt 7500
Evaluating prompt 7550
Evaluating prompt 7600
Evaluating prompt 7650
Evaluating prompt 7700
Evaluating prompt 7750
Evaluating prompt 7800
Evaluating prompt 7850
Evaluating prompt 7900
Evaluating prompt 7950
Evaluating prompt 8000
Evaluating prompt 8050
Evaluating prompt 8100
Evaluating prompt 8150
Evaluating prompt 8200
Evaluating prompt 8250
Evaluating prompt 8300
Evaluating prompt 8350
Evaluating prompt 8400
Evaluating prompt 8450
Evaluating prompt 8500
Evaluating prompt 8550
Evaluating prompt 8600
Evaluating prompt 8650
Evaluating prompt 8700
Evaluating prompt 8750
Evaluating prompt 8800
Evaluating prompt 8850
Evaluating prompt 8900
Evaluating prompt 8950
Evaluating prompt 9000
Evaluating prompt 9050
Evaluating prompt 9100
Evaluating prompt 9150
Evaluating prompt 9200
Evaluating prompt 9250
Evaluating prompt 9300
Evaluating prompt 9350
Evaluating prompt 9400
Evaluating prompt 9450
Evaluating prompt 9500
Evaluating prompt 9550
Evaluating prompt 9600
Evaluating prompt 9650
Evaluating prompt 9700
Evaluating prompt 9750
Evaluating prompt 9800
Evaluating prompt 9850
Evaluating prompt 9900
Evaluating prompt 9950
Performance after mean ablating attention head (7, 3): 0.712
Evaluating prompt 0
Evaluating prompt 50
Evaluating prompt 100
Evaluating prompt 150
Evaluating prompt 200
Evaluating prompt 250
Evaluating prompt 300
Evaluating prompt 350
Evaluating prompt 400
Evaluating prompt 450
Evaluating prompt 500
Evaluating prompt 550
Evaluating prompt 600
Evaluating prompt 650
Evaluating prompt 700
Evaluating prompt 750
Evaluating prompt 800
Evaluating prompt 850
Evaluating prompt 900
Evaluating prompt 950
Evaluating prompt 1000
Evaluating prompt 1050
Evaluating prompt 1100
Evaluating prompt 1150
Evaluating prompt 1200
Evaluating prompt 1250
Evaluating prompt 1300
Evaluating prompt 1350
Evaluating prompt 1400
Evaluating prompt 1450
Evaluating prompt 1500
Evaluating prompt 1550
Evaluating prompt 1600
Evaluating prompt 1650
Evaluating prompt 1700
Evaluating prompt 1750
Evaluating prompt 1800
Evaluating prompt 1850
Evaluating prompt 1900
Evaluating prompt 1950
Evaluating prompt 2000
Evaluating prompt 2050
Evaluating prompt 2100
Evaluating prompt 2150
Evaluating prompt 2200
Evaluating prompt 2250
Evaluating prompt 2300
Evaluating prompt 2350
Evaluating prompt 2400
Evaluating prompt 2450
Evaluating prompt 2500
Evaluating prompt 2550
Evaluating prompt 2600
Evaluating prompt 2650
Evaluating prompt 2700
Evaluating prompt 2750
Evaluating prompt 2800
Evaluating prompt 2850
Evaluating prompt 2900
Evaluating prompt 2950
Evaluating prompt 3000
Evaluating prompt 3050
Evaluating prompt 3100
Evaluating prompt 3150
Evaluating prompt 3200
Evaluating prompt 3250
Evaluating prompt 3300
Evaluating prompt 3350
Evaluating prompt 3400
Evaluating prompt 3450
Evaluating prompt 3500
Evaluating prompt 3550
Evaluating prompt 3600
Evaluating prompt 3650
Evaluating prompt 3700
Evaluating prompt 3750
Evaluating prompt 3800
Evaluating prompt 3850
Evaluating prompt 3900
Evaluating prompt 3950
Evaluating prompt 4000
Evaluating prompt 4050
Evaluating prompt 4100
Evaluating prompt 4150
Evaluating prompt 4200
Evaluating prompt 4250
Evaluating prompt 4300
Evaluating prompt 4350
Evaluating prompt 4400
Evaluating prompt 4450
Evaluating prompt 4500
Evaluating prompt 4550
Evaluating prompt 4600
Evaluating prompt 4650
Evaluating prompt 4700
Evaluating prompt 4750
Evaluating prompt 4800
Evaluating prompt 4850
Evaluating prompt 4900
Evaluating prompt 4950
Evaluating prompt 5000
Evaluating prompt 5050
Evaluating prompt 5100
Evaluating prompt 5150
Evaluating prompt 5200
Evaluating prompt 5250
Evaluating prompt 5300
Evaluating prompt 5350
Evaluating prompt 5400
Evaluating prompt 5450
Evaluating prompt 5500
Evaluating prompt 5550
Evaluating prompt 5600
Evaluating prompt 5650
Evaluating prompt 5700
Evaluating prompt 5750
Evaluating prompt 5800
Evaluating prompt 5850
Evaluating prompt 5900
Evaluating prompt 5950
Evaluating prompt 6000
Evaluating prompt 6050
Evaluating prompt 6100
Evaluating prompt 6150
Evaluating prompt 6200
Evaluating prompt 6250
Evaluating prompt 6300
Evaluating prompt 6350
Evaluating prompt 6400
Evaluating prompt 6450
Evaluating prompt 6500
Evaluating prompt 6550
Evaluating prompt 6600
Evaluating prompt 6650
Evaluating prompt 6700
Evaluating prompt 6750
Evaluating prompt 6800
Evaluating prompt 6850
Evaluating prompt 6900
Evaluating prompt 6950
Evaluating prompt 7000
Evaluating prompt 7050
Evaluating prompt 7100
Evaluating prompt 7150
Evaluating prompt 7200
Evaluating prompt 7250
Evaluating prompt 7300
Evaluating prompt 7350
Evaluating prompt 7400
Evaluating prompt 7450
Evaluating prompt 7500
Evaluating prompt 7550
Evaluating prompt 7600
Evaluating prompt 7650
Evaluating prompt 7700
Evaluating prompt 7750
Evaluating prompt 7800
Evaluating prompt 7850
Evaluating prompt 7900
Evaluating prompt 7950
Evaluating prompt 8000
Evaluating prompt 8050
Evaluating prompt 8100
Evaluating prompt 8150
Evaluating prompt 8200
Evaluating prompt 8250
Evaluating prompt 8300
Evaluating prompt 8350
Evaluating prompt 8400
Evaluating prompt 8450
Evaluating prompt 8500
Evaluating prompt 8550
Evaluating prompt 8600
Evaluating prompt 8650
Evaluating prompt 8700
Evaluating prompt 8750
Evaluating prompt 8800
Evaluating prompt 8850
Evaluating prompt 8900
Evaluating prompt 8950
Evaluating prompt 9000
Evaluating prompt 9050
Evaluating prompt 9100
Evaluating prompt 9150
Evaluating prompt 9200
Evaluating prompt 9250
Evaluating prompt 9300
Evaluating prompt 9350
Evaluating prompt 9400
Evaluating prompt 9450
Evaluating prompt 9500
Evaluating prompt 9550
Evaluating prompt 9600
Evaluating prompt 9650
Evaluating prompt 9700
Evaluating prompt 9750
Evaluating prompt 9800
Evaluating prompt 9850
Evaluating prompt 9900
Evaluating prompt 9950
Performance after mean ablating attention head (8, 10): 0.7912
Evaluating prompt 0
Evaluating prompt 50
Evaluating prompt 100
Evaluating prompt 150
Evaluating prompt 200
Evaluating prompt 250
Evaluating prompt 300
Evaluating prompt 350
Evaluating prompt 400
Evaluating prompt 450
Evaluating prompt 500
Evaluating prompt 550
Evaluating prompt 600
Evaluating prompt 650
Evaluating prompt 700
Evaluating prompt 750
Evaluating prompt 800
Evaluating prompt 850
Evaluating prompt 900
Evaluating prompt 950
Evaluating prompt 1000
Evaluating prompt 1050
Evaluating prompt 1100
Evaluating prompt 1150
Evaluating prompt 1200
Evaluating prompt 1250
Evaluating prompt 1300
Evaluating prompt 1350
Evaluating prompt 1400
Evaluating prompt 1450
Evaluating prompt 1500
Evaluating prompt 1550
Evaluating prompt 1600
Evaluating prompt 1650
Evaluating prompt 1700
Evaluating prompt 1750
Evaluating prompt 1800
Evaluating prompt 1850
Evaluating prompt 1900
Evaluating prompt 1950
Evaluating prompt 2000
Evaluating prompt 2050
Evaluating prompt 2100
Evaluating prompt 2150
Evaluating prompt 2200
Evaluating prompt 2250
Evaluating prompt 2300
Evaluating prompt 2350
Evaluating prompt 2400
Evaluating prompt 2450
Evaluating prompt 2500
Evaluating prompt 2550
Evaluating prompt 2600
Evaluating prompt 2650
Evaluating prompt 2700
Evaluating prompt 2750
Evaluating prompt 2800
Evaluating prompt 2850
Evaluating prompt 2900
Evaluating prompt 2950
Evaluating prompt 3000
Evaluating prompt 3050
Evaluating prompt 3100
Evaluating prompt 3150
Evaluating prompt 3200
Evaluating prompt 3250
Evaluating prompt 3300
Evaluating prompt 3350
Evaluating prompt 3400
Evaluating prompt 3450
Evaluating prompt 3500
Evaluating prompt 3550
Evaluating prompt 3600
Evaluating prompt 3650
Evaluating prompt 3700
Evaluating prompt 3750
Evaluating prompt 3800
Evaluating prompt 3850
Evaluating prompt 3900
Evaluating prompt 3950
Evaluating prompt 4000
Evaluating prompt 4050
Evaluating prompt 4100
Evaluating prompt 4150
Evaluating prompt 4200
Evaluating prompt 4250
Evaluating prompt 4300
Evaluating prompt 4350
Evaluating prompt 4400
Evaluating prompt 4450
Evaluating prompt 4500
Evaluating prompt 4550
Evaluating prompt 4600
Evaluating prompt 4650
Evaluating prompt 4700
Evaluating prompt 4750
Evaluating prompt 4800
Evaluating prompt 4850
Evaluating prompt 4900
Evaluating prompt 4950
Evaluating prompt 5000
Evaluating prompt 5050
Evaluating prompt 5100
Evaluating prompt 5150
Evaluating prompt 5200
Evaluating prompt 5250
Evaluating prompt 5300
Evaluating prompt 5350
Evaluating prompt 5400
Evaluating prompt 5450
Evaluating prompt 5500
Evaluating prompt 5550
Evaluating prompt 5600
Evaluating prompt 5650
Evaluating prompt 5700
Evaluating prompt 5750
Evaluating prompt 5800
Evaluating prompt 5850
Evaluating prompt 5900
Evaluating prompt 5950
Evaluating prompt 6000
Evaluating prompt 6050
Evaluating prompt 6100
Evaluating prompt 6150
Evaluating prompt 6200
Evaluating prompt 6250
Evaluating prompt 6300
Evaluating prompt 6350
Evaluating prompt 6400
Evaluating prompt 6450
Evaluating prompt 6500
Evaluating prompt 6550
Evaluating prompt 6600
Evaluating prompt 6650
Evaluating prompt 6700
Evaluating prompt 6750
Evaluating prompt 6800
Evaluating prompt 6850
Evaluating prompt 6900
Evaluating prompt 6950
Evaluating prompt 7000
Evaluating prompt 7050
Evaluating prompt 7100
Evaluating prompt 7150
Evaluating prompt 7200
Evaluating prompt 7250
Evaluating prompt 7300
Evaluating prompt 7350
Evaluating prompt 7400
Evaluating prompt 7450
Evaluating prompt 7500
Evaluating prompt 7550
Evaluating prompt 7600
Evaluating prompt 7650
Evaluating prompt 7700
Evaluating prompt 7750
Evaluating prompt 7800
Evaluating prompt 7850
Evaluating prompt 7900
Evaluating prompt 7950
Evaluating prompt 8000
Evaluating prompt 8050
Evaluating prompt 8100
Evaluating prompt 8150
Evaluating prompt 8200
Evaluating prompt 8250
Evaluating prompt 8300
Evaluating prompt 8350
Evaluating prompt 8400
Evaluating prompt 8450
Evaluating prompt 8500
Evaluating prompt 8550
Evaluating prompt 8600
Evaluating prompt 8650
Evaluating prompt 8700
Evaluating prompt 8750
Evaluating prompt 8800
Evaluating prompt 8850
Evaluating prompt 8900
Evaluating prompt 8950
Evaluating prompt 9000
Evaluating prompt 9050
Evaluating prompt 9100
Evaluating prompt 9150
Evaluating prompt 9200
Evaluating prompt 9250
Evaluating prompt 9300
Evaluating prompt 9350
Evaluating prompt 9400
Evaluating prompt 9450
Evaluating prompt 9500
Evaluating prompt 9550
Evaluating prompt 9600
Evaluating prompt 9650
Evaluating prompt 9700
Evaluating prompt 9750
Evaluating prompt 9800
Evaluating prompt 9850
Evaluating prompt 9900
Evaluating prompt 9950
Performance after mean ablating attention head (9, 6): 0.8386
Evaluating prompt 0
Evaluating prompt 50
Evaluating prompt 100
Evaluating prompt 150
Evaluating prompt 200
Evaluating prompt 250
Evaluating prompt 300
Evaluating prompt 350
Evaluating prompt 400
Evaluating prompt 450
Evaluating prompt 500
Evaluating prompt 550
Evaluating prompt 600
Evaluating prompt 650
Evaluating prompt 700
Evaluating prompt 750
Evaluating prompt 800
Evaluating prompt 850
Evaluating prompt 900
Evaluating prompt 950
Evaluating prompt 1000
Evaluating prompt 1050
Evaluating prompt 1100
Evaluating prompt 1150
Evaluating prompt 1200
Evaluating prompt 1250
Evaluating prompt 1300
Evaluating prompt 1350
Evaluating prompt 1400
Evaluating prompt 1450
Evaluating prompt 1500
Evaluating prompt 1550
Evaluating prompt 1600
Evaluating prompt 1650
Evaluating prompt 1700
Evaluating prompt 1750
Evaluating prompt 1800
Evaluating prompt 1850
Evaluating prompt 1900
Evaluating prompt 1950
Evaluating prompt 2000
Evaluating prompt 2050
Evaluating prompt 2100
Evaluating prompt 2150
Evaluating prompt 2200
Evaluating prompt 2250
Evaluating prompt 2300
Evaluating prompt 2350
Evaluating prompt 2400
Evaluating prompt 2450
Evaluating prompt 2500
Evaluating prompt 2550
Evaluating prompt 2600
Evaluating prompt 2650
Evaluating prompt 2700
Evaluating prompt 2750
Evaluating prompt 2800
Evaluating prompt 2850
Evaluating prompt 2900
Evaluating prompt 2950
Evaluating prompt 3000
Evaluating prompt 3050
Evaluating prompt 3100
Evaluating prompt 3150
Evaluating prompt 3200
Evaluating prompt 3250
Evaluating prompt 3300
Evaluating prompt 3350
Evaluating prompt 3400
Evaluating prompt 3450
Evaluating prompt 3500
Evaluating prompt 3550
Evaluating prompt 3600
Evaluating prompt 3650
Evaluating prompt 3700
Evaluating prompt 3750
Evaluating prompt 3800
Evaluating prompt 3850
Evaluating prompt 3900
Evaluating prompt 3950
Evaluating prompt 4000
Evaluating prompt 4050
Evaluating prompt 4100
Evaluating prompt 4150
Evaluating prompt 4200
Evaluating prompt 4250
Evaluating prompt 4300
Evaluating prompt 4350
Evaluating prompt 4400
Evaluating prompt 4450
Evaluating prompt 4500
Evaluating prompt 4550
Evaluating prompt 4600
Evaluating prompt 4650
Evaluating prompt 4700
Evaluating prompt 4750
Evaluating prompt 4800
Evaluating prompt 4850
Evaluating prompt 4900
Evaluating prompt 4950
Evaluating prompt 5000
Evaluating prompt 5050
Evaluating prompt 5100
Evaluating prompt 5150
Evaluating prompt 5200
Evaluating prompt 5250
Evaluating prompt 5300
Evaluating prompt 5350
Evaluating prompt 5400
Evaluating prompt 5450
Evaluating prompt 5500
Evaluating prompt 5550
Evaluating prompt 5600
Evaluating prompt 5650
Evaluating prompt 5700
Evaluating prompt 5750
Evaluating prompt 5800
Evaluating prompt 5850
Evaluating prompt 5900
Evaluating prompt 5950
Evaluating prompt 6000
Evaluating prompt 6050
Evaluating prompt 6100
Evaluating prompt 6150
Evaluating prompt 6200
Evaluating prompt 6250
Evaluating prompt 6300
Evaluating prompt 6350
Evaluating prompt 6400
Evaluating prompt 6450
Evaluating prompt 6500
Evaluating prompt 6550
Evaluating prompt 6600
Evaluating prompt 6650
Evaluating prompt 6700
Evaluating prompt 6750
Evaluating prompt 6800
Evaluating prompt 6850
Evaluating prompt 6900
Evaluating prompt 6950
Evaluating prompt 7000
Evaluating prompt 7050
Evaluating prompt 7100
Evaluating prompt 7150
Evaluating prompt 7200
Evaluating prompt 7250
Evaluating prompt 7300
Evaluating prompt 7350
Evaluating prompt 7400
Evaluating prompt 7450
Evaluating prompt 7500
Evaluating prompt 7550
Evaluating prompt 7600
Evaluating prompt 7650
Evaluating prompt 7700
Evaluating prompt 7750
Evaluating prompt 7800
Evaluating prompt 7850
Evaluating prompt 7900
Evaluating prompt 7950
Evaluating prompt 8000
Evaluating prompt 8050
Evaluating prompt 8100
Evaluating prompt 8150
Evaluating prompt 8200
Evaluating prompt 8250
Evaluating prompt 8300
Evaluating prompt 8350
Evaluating prompt 8400
Evaluating prompt 8450
Evaluating prompt 8500
Evaluating prompt 8550
Evaluating prompt 8600
Evaluating prompt 8650
Evaluating prompt 8700
Evaluating prompt 8750
Evaluating prompt 8800
Evaluating prompt 8850
Evaluating prompt 8900
Evaluating prompt 8950
Evaluating prompt 9000
Evaluating prompt 9050
Evaluating prompt 9100
Evaluating prompt 9150
Evaluating prompt 9200
Evaluating prompt 9250
Evaluating prompt 9300
Evaluating prompt 9350
Evaluating prompt 9400
Evaluating prompt 9450
Evaluating prompt 9500
Evaluating prompt 9550
Evaluating prompt 9600
Evaluating prompt 9650
Evaluating prompt 9700
Evaluating prompt 9750
Evaluating prompt 9800
Evaluating prompt 9850
Evaluating prompt 9900
Evaluating prompt 9950
Performance after mean ablating attention head (9, 9): 0.7772
Evaluating prompt 0
Evaluating prompt 50
Evaluating prompt 100
Evaluating prompt 150
Evaluating prompt 200
Evaluating prompt 250
Evaluating prompt 300
Evaluating prompt 350
Evaluating prompt 400
Evaluating prompt 450
Evaluating prompt 500
Evaluating prompt 550
Evaluating prompt 600
Evaluating prompt 650
Evaluating prompt 700
Evaluating prompt 750
Evaluating prompt 800
Evaluating prompt 850
Evaluating prompt 900
Evaluating prompt 950
Evaluating prompt 1000
Evaluating prompt 1050
Evaluating prompt 1100
Evaluating prompt 1150
Evaluating prompt 1200
Evaluating prompt 1250
Evaluating prompt 1300
Evaluating prompt 1350
Evaluating prompt 1400
Evaluating prompt 1450
Evaluating prompt 1500
Evaluating prompt 1550
Evaluating prompt 1600
Evaluating prompt 1650
Evaluating prompt 1700
Evaluating prompt 1750
Evaluating prompt 1800
Evaluating prompt 1850
Evaluating prompt 1900
Evaluating prompt 1950
Evaluating prompt 2000
Evaluating prompt 2050
Evaluating prompt 2100
Evaluating prompt 2150
Evaluating prompt 2200
Evaluating prompt 2250
Evaluating prompt 2300
Evaluating prompt 2350
Evaluating prompt 2400
Evaluating prompt 2450
Evaluating prompt 2500
Evaluating prompt 2550
Evaluating prompt 2600
Evaluating prompt 2650
Evaluating prompt 2700
Evaluating prompt 2750
Evaluating prompt 2800
Evaluating prompt 2850
Evaluating prompt 2900
Evaluating prompt 2950
Evaluating prompt 3000
Evaluating prompt 3050
Evaluating prompt 3100
Evaluating prompt 3150
Evaluating prompt 3200
Evaluating prompt 3250
Evaluating prompt 3300
Evaluating prompt 3350
Evaluating prompt 3400
Evaluating prompt 3450
Evaluating prompt 3500
Evaluating prompt 3550
Evaluating prompt 3600
Evaluating prompt 3650
Evaluating prompt 3700
Evaluating prompt 3750
Evaluating prompt 3800
Evaluating prompt 3850
Evaluating prompt 3900
Evaluating prompt 3950
Evaluating prompt 4000
Evaluating prompt 4050
Evaluating prompt 4100
Evaluating prompt 4150
Evaluating prompt 4200
Evaluating prompt 4250
Evaluating prompt 4300
Evaluating prompt 4350
Evaluating prompt 4400
Evaluating prompt 4450
Evaluating prompt 4500
Evaluating prompt 4550
Evaluating prompt 4600
Evaluating prompt 4650
Evaluating prompt 4700
Evaluating prompt 4750
Evaluating prompt 4800
Evaluating prompt 4850
Evaluating prompt 4900
Evaluating prompt 4950
Evaluating prompt 5000
Evaluating prompt 5050
Evaluating prompt 5100
Evaluating prompt 5150
Evaluating prompt 5200
Evaluating prompt 5250
Evaluating prompt 5300
Evaluating prompt 5350
Evaluating prompt 5400
Evaluating prompt 5450
Evaluating prompt 5500
Evaluating prompt 5550
Evaluating prompt 5600
Evaluating prompt 5650
Evaluating prompt 5700
Evaluating prompt 5750
Evaluating prompt 5800
Evaluating prompt 5850
Evaluating prompt 5900
Evaluating prompt 5950
Evaluating prompt 6000
Evaluating prompt 6050
Evaluating prompt 6100
Evaluating prompt 6150
Evaluating prompt 6200
Evaluating prompt 6250
Evaluating prompt 6300
Evaluating prompt 6350
Evaluating prompt 6400
Evaluating prompt 6450
Evaluating prompt 6500
Evaluating prompt 6550
Evaluating prompt 6600
Evaluating prompt 6650
Evaluating prompt 6700
Evaluating prompt 6750
Evaluating prompt 6800
Evaluating prompt 6850
Evaluating prompt 6900
Evaluating prompt 6950
Evaluating prompt 7000
Evaluating prompt 7050
Evaluating prompt 7100
Evaluating prompt 7150
Evaluating prompt 7200
Evaluating prompt 7250
Evaluating prompt 7300
Evaluating prompt 7350
Evaluating prompt 7400
Evaluating prompt 7450
Evaluating prompt 7500
Evaluating prompt 7550
Evaluating prompt 7600
Evaluating prompt 7650
Evaluating prompt 7700
Evaluating prompt 7750
Evaluating prompt 7800
Evaluating prompt 7850
Evaluating prompt 7900
Evaluating prompt 7950
Evaluating prompt 8000
Evaluating prompt 8050
Evaluating prompt 8100
Evaluating prompt 8150
Evaluating prompt 8200
Evaluating prompt 8250
Evaluating prompt 8300
Evaluating prompt 8350
Evaluating prompt 8400
Evaluating prompt 8450
Evaluating prompt 8500
Evaluating prompt 8550
Evaluating prompt 8600
Evaluating prompt 8650
Evaluating prompt 8700
Evaluating prompt 8750
Evaluating prompt 8800
Evaluating prompt 8850
Evaluating prompt 8900
Evaluating prompt 8950
Evaluating prompt 9000
Evaluating prompt 9050
Evaluating prompt 9100
Evaluating prompt 9150
Evaluating prompt 9200
Evaluating prompt 9250
Evaluating prompt 9300
Evaluating prompt 9350
Evaluating prompt 9400
Evaluating prompt 9450
Evaluating prompt 9500
Evaluating prompt 9550
Evaluating prompt 9600
Evaluating prompt 9650
Evaluating prompt 9700
Evaluating prompt 9750
Evaluating prompt 9800
Evaluating prompt 9850
Evaluating prompt 9900
Evaluating prompt 9950
Performance after mean ablating attention head (10, 6): 0.7141
Evaluating prompt 0
Evaluating prompt 50
Evaluating prompt 100
Evaluating prompt 150
Evaluating prompt 200
Evaluating prompt 250
Evaluating prompt 300
Evaluating prompt 350
Evaluating prompt 400
Evaluating prompt 450
Evaluating prompt 500
Evaluating prompt 550
Evaluating prompt 600
Evaluating prompt 650
Evaluating prompt 700
Evaluating prompt 750
Evaluating prompt 800
Evaluating prompt 850
Evaluating prompt 900
Evaluating prompt 950
Evaluating prompt 1000
Evaluating prompt 1050
Evaluating prompt 1100
Evaluating prompt 1150
Evaluating prompt 1200
Evaluating prompt 1250
Evaluating prompt 1300
Evaluating prompt 1350
Evaluating prompt 1400
Evaluating prompt 1450
Evaluating prompt 1500
Evaluating prompt 1550
Evaluating prompt 1600
Evaluating prompt 1650
Evaluating prompt 1700
Evaluating prompt 1750
Evaluating prompt 1800
Evaluating prompt 1850
Evaluating prompt 1900
Evaluating prompt 1950
Evaluating prompt 2000
Evaluating prompt 2050
Evaluating prompt 2100
Evaluating prompt 2150
Evaluating prompt 2200
Evaluating prompt 2250
Evaluating prompt 2300
Evaluating prompt 2350
Evaluating prompt 2400
Evaluating prompt 2450
Evaluating prompt 2500
Evaluating prompt 2550
Evaluating prompt 2600
Evaluating prompt 2650
Evaluating prompt 2700
Evaluating prompt 2750
Evaluating prompt 2800
Evaluating prompt 2850
Evaluating prompt 2900
Evaluating prompt 2950
Evaluating prompt 3000
Evaluating prompt 3050
Evaluating prompt 3100
Evaluating prompt 3150
Evaluating prompt 3200
Evaluating prompt 3250
Evaluating prompt 3300
Evaluating prompt 3350
Evaluating prompt 3400
Evaluating prompt 3450
Evaluating prompt 3500
Evaluating prompt 3550
Evaluating prompt 3600
Evaluating prompt 3650
Evaluating prompt 3700
Evaluating prompt 3750
Evaluating prompt 3800
Evaluating prompt 3850
Evaluating prompt 3900
Evaluating prompt 3950
Evaluating prompt 4000
Evaluating prompt 4050
Evaluating prompt 4100
Evaluating prompt 4150
Evaluating prompt 4200
Evaluating prompt 4250
Evaluating prompt 4300
Evaluating prompt 4350
Evaluating prompt 4400
Evaluating prompt 4450
Evaluating prompt 4500
Evaluating prompt 4550
Evaluating prompt 4600
Evaluating prompt 4650
Evaluating prompt 4700
Evaluating prompt 4750
Evaluating prompt 4800
Evaluating prompt 4850
Evaluating prompt 4900
Evaluating prompt 4950
Evaluating prompt 5000
Evaluating prompt 5050
Evaluating prompt 5100
Evaluating prompt 5150
Evaluating prompt 5200
Evaluating prompt 5250
Evaluating prompt 5300
Evaluating prompt 5350
Evaluating prompt 5400
Evaluating prompt 5450
Evaluating prompt 5500
Evaluating prompt 5550
Evaluating prompt 5600
Evaluating prompt 5650
Evaluating prompt 5700
Evaluating prompt 5750
Evaluating prompt 5800
Evaluating prompt 5850
Evaluating prompt 5900
Evaluating prompt 5950
Evaluating prompt 6000
Evaluating prompt 6050
Evaluating prompt 6100
Evaluating prompt 6150
Evaluating prompt 6200
Evaluating prompt 6250
Evaluating prompt 6300
Evaluating prompt 6350
Evaluating prompt 6400
Evaluating prompt 6450
Evaluating prompt 6500
Evaluating prompt 6550
Evaluating prompt 6600
Evaluating prompt 6650
Evaluating prompt 6700
Evaluating prompt 6750
Evaluating prompt 6800
Evaluating prompt 6850
Evaluating prompt 6900
Evaluating prompt 6950
Evaluating prompt 7000
Evaluating prompt 7050
Evaluating prompt 7100
Evaluating prompt 7150
Evaluating prompt 7200
Evaluating prompt 7250
Evaluating prompt 7300
Evaluating prompt 7350
Evaluating prompt 7400
Evaluating prompt 7450
Evaluating prompt 7500
Evaluating prompt 7550
Evaluating prompt 7600
Evaluating prompt 7650
Evaluating prompt 7700
Evaluating prompt 7750
Evaluating prompt 7800
Evaluating prompt 7850
Evaluating prompt 7900
Evaluating prompt 7950
Evaluating prompt 8000
Evaluating prompt 8050
Evaluating prompt 8100
Evaluating prompt 8150
Evaluating prompt 8200
Evaluating prompt 8250
Evaluating prompt 8300
Evaluating prompt 8350
Evaluating prompt 8400
Evaluating prompt 8450
Evaluating prompt 8500
Evaluating prompt 8550
Evaluating prompt 8600
Evaluating prompt 8650
Evaluating prompt 8700
Evaluating prompt 8750
Evaluating prompt 8800
Evaluating prompt 8850
Evaluating prompt 8900
Evaluating prompt 8950
Evaluating prompt 9000
Evaluating prompt 9050
Evaluating prompt 9100
Evaluating prompt 9150
Evaluating prompt 9200
Evaluating prompt 9250
Evaluating prompt 9300
Evaluating prompt 9350
Evaluating prompt 9400
Evaluating prompt 9450
Evaluating prompt 9500
Evaluating prompt 9550
Evaluating prompt 9600
Evaluating prompt 9650
Evaluating prompt 9700
Evaluating prompt 9750
Evaluating prompt 9800
Evaluating prompt 9850
Evaluating prompt 9900
Evaluating prompt 9950
Performance after mean ablating attention head (11, 10): 0.9436
Code
# Print the mean-ablation performance scores, then persist them for later analysis.
print(all_performance_scores_mean_ablation)

scores_path = "all_performance_scores_mean_ablation.npy"
np.save(scores_path, all_performance_scores_mean_ablation)
[0.712, 0.7912, 0.8386, 0.7772, 0.7141, 0.9436]
Code
# Bar chart: baseline IOI performance vs performance after mean-ablating each
# outlier attention head ("None" = no ablation).
baseline_performance = 0.7864

xs = ["None", *map(str, attn_outliers)]
ys = [baseline_performance, *all_performance_scores_mean_ablation]

plt.title("Model performance after mean ablation of attention head outliers")
plt.xlabel("Ablated attention head position")
plt.ylabel("Model performance on IOI tasks")
plt.bar(xs, ys)
plt.show()

Code
# Correlation between difference in attribution scores and difference in performance

mean_performance_differences = []
ig_outlier_scores = []
ap_outlier_scores = []
score_differences = []

# Pair each ablation performance score with its outlier head (layer, head_index)
# instead of indexing both lists by position.
for performance, (layer, attn_idx) in zip(all_performance_scores_mean_ablation, attn_outliers):
    mean_performance_differences.append(performance - baseline_performance)
    ap_score = attn_patch_results[layer, attn_idx]
    ig_score = attn_ig_results[layer, attn_idx]
    ap_outlier_scores.append(ap_score)
    ig_outlier_scores.append(ig_score)
    # Positive difference: AP attributes more importance to this head than IG.
    score_differences.append(ap_score - ig_score)
Code
# Correlation between attribution scores and performance change

sns.regplot(x=ig_outlier_scores, y=mean_performance_differences)
plt.ylabel("Difference between performance and baseline performance under mean ablation")
plt.xlabel("IG attribution scores")
plt.show()

print(f"Correlation coefficient between IG attribution score and performance score difference under mean ablation: {np.corrcoef(ig_outlier_scores, mean_performance_differences)[0, 1]}")

sns.regplot(x=ap_outlier_scores, y=mean_performance_differences)
plt.ylabel("Difference between performance and baseline performance")
plt.xlabel("Activation Patching attribution scores")
plt.show()

# Fixed copy-paste bug: this message previously said "IG attribution score"
# even though it reports the Activation Patching correlation.
print(f"Correlation coefficient between Activation Patching attribution score and performance score difference under mean ablation: {np.corrcoef(ap_outlier_scores, mean_performance_differences)[0, 1]}")

Correlation coefficient between IG attribution score and performance score difference under mean ablation: 0.7868078399469188

Correlation coefficient between IG attribution score and performance score difference under mean ablation: 0.8659126942995752

Heads for different tasks

Hypothesis: outliers highlighted by IG but not by AP may not be important to the specific task, but they may affect performance on other tasks instead.

  • IG calculates change in loss based on integrating gradients between two input values.
  • A high attribution score could be caused by strong gradients (sensitivity) up until an intermediate input value (in between the two input values). In this case, the highlighted component would be important for the task “in between” (represented by different counterfactual inputs) instead of the target task.

To test this, we can visualise the gradients for intervals which are summed up by IG. We focus on attention head (9, 6) because it is highlighted more strongly by IG than by AP.

Visualising gradients

Code
def mean_attribution(attribution_scores, pos=None):
    """Collapse per-element attribution scores to one score per position.

    Averages over the embedding dimension (dim 3), then over the token
    dimension (dim 1). If ``pos`` is given, returns only the score at that
    position; otherwise returns the scores for all positions.
    """
    collapsed = attribution_scores.mean(dim=3).mean(dim=1)
    return collapsed if pos is None else collapsed[:, pos]
Code
%load_ext autoreload
%autoreload 2
Code
%reload_ext autoreload

from integrated_gradients import CustomLayerIntegratedGradients

n_steps = 50

def attn_gradient_attribution_history(target_layer_num, target_pos):
    """IG attribution for one attention head plus its per-step gradient history.

    Attributes to the head-output hook ("result") of layer ``target_layer_num``,
    interpolating the preceding "z" activations between the clean cache
    (baseline) and the corrupted cache (input). Returns the mean attribution
    score at ``target_pos``, the mean gradient at each interpolation step,
    and the interpolation coefficients (alphas).
    """
    result_hook = get_act_name("result", target_layer_num)
    z_hook = get_act_name("z", target_layer_num)
    result_layer = model.hook_dict[result_hook]
    z_layer = model.hook_dict[z_hook]

    # Interpolation endpoints: clean run is the baseline, corrupted run the input.
    baseline_acts = clean_cache[z_hook]
    corrupt_acts = corrupted_cache[z_hook]

    # Run the model forward starting from the "z" hook.
    def forward_fn(x):
        return run_from_layer_fn(x, clean_input, z_layer)

    # Attribute to the "result" layer's output.
    ig = CustomLayerIntegratedGradients(forward_fn, result_layer, multiply_by_inputs=True)
    attributions, gradient_history, alphas = ig.attribute(
        inputs=corrupt_acts,
        baselines=baseline_acts,
        n_steps=n_steps,
        attribute_to_layer_input=False,
        return_convergence_delta=False,
        return_gradient_history=True,
    )

    # attribute() wraps the gradient history in a one-element tuple.
    gradient_history = gradient_history[0]

    return (
        mean_attribution(attributions, pos=target_pos),
        mean_attribution(gradient_history, pos=target_pos),
        alphas,
    )
Code
def visualise_attn_grad_attribution(target_layer_num, target_pos):
    """Print and plot the IG gradient history for one attention head.

    Prints the head's attribution score, checks it against the previously
    computed `attn_ig_results`, and plots the mean gradient at each
    interpolation step between baseline and input.
    """
    # Gradient attribution history for the target head (target_layer_num, target_pos)
    score, gradient_history, alphas = attn_gradient_attribution_history(target_layer_num, target_pos)

    # Move off the accelerator for numpy / matplotlib interop
    score = score.detach().cpu()
    gradient_history = gradient_history.detach().cpu()

    print(f"Attribution score for head {(target_layer_num, target_pos)}: {score}")

    # Sanity check: attribution scores should match the earlier IG sweep
    print("Matches previous IG score: ", np.isclose(attn_ig_results[target_layer_num, target_pos], score))

    plt.title(f"Head {(target_layer_num, target_pos)}: Gradient magnitude between baseline and input for IG")
    plt.plot(alphas, gradient_history)
    plt.xlabel("Interpolation index")
    plt.ylabel("Gradient magnitude")
    plt.show()

    # Identify interpolation turning point
    # NOTE(review): this collects *all* step indices whose |gradient| exceeds
    # 1e-5, not a single turning point — confirm this is the intended summary.
    print(gradient_history)
    turning_point = torch.where(torch.abs(gradient_history) > 1e-5)[0]
    print(f"Interpolation turning point: {turning_point}")
Code
# Gradient magnitude for head exclusively highlighted by IG as important
# (high IG attribution score, low activation-patching score)

visualise_attn_grad_attribution(9, 6)
Running attribution!
Input: (tensor([[[[-3.8202e-02, -1.4827e-01,  1.0766e-01,  ...,  1.3316e-01,
           -9.8657e-02, -6.1993e-02],
          [ 1.7961e-02, -7.9668e-02, -4.2784e-02,  ...,  1.6037e-03,
            2.6998e-02,  5.7171e-02],
          [ 4.7301e-02, -2.6343e-02, -8.6350e-02,  ...,  4.2694e-02,
           -6.2673e-02, -7.3593e-03],
          ...,
          [ 5.6734e-02,  1.3328e-02,  1.5833e-03,  ..., -1.0371e-01,
            3.4959e-02,  3.3161e-02],
          [-1.3501e-01,  2.5231e-01,  1.2611e-01,  ..., -2.9068e-02,
           -1.3579e-01,  7.7787e-02],
          [ 3.4632e-02,  2.9779e-02, -5.6488e-04,  ...,  3.8829e-02,
           -4.0786e-02,  1.3188e-01]],

         [[-4.2960e-02, -1.4648e-01,  1.1146e-01,  ...,  1.3758e-01,
           -9.8202e-02, -5.6699e-02],
          [ 2.0101e-02, -9.2041e-02, -4.7995e-02,  ..., -7.4466e-03,
            2.3998e-02,  5.9188e-02],
          [ 6.2252e-02, -1.3594e-02, -8.9068e-02,  ...,  5.0048e-02,
           -6.6885e-02,  7.8406e-03],
          ...,
          [ 5.8222e-02,  1.4373e-02,  5.5309e-03,  ..., -1.0639e-01,
            3.1540e-02,  3.6838e-02],
          [-1.3131e-01,  2.4665e-01,  1.3239e-01,  ..., -4.0621e-02,
           -1.5014e-01,  5.5906e-02],
          [ 3.5457e-02,  3.1159e-02, -6.4574e-04,  ...,  3.9075e-02,
           -4.1930e-02,  1.3077e-01]],

         [[-4.0339e-02, -1.9827e-01,  1.2731e-01,  ...,  8.5587e-02,
           -1.4614e-01, -4.5109e-02],
          [ 1.7036e-02, -8.4320e-02, -4.3065e-02,  ...,  4.8828e-04,
            2.5040e-02,  5.5996e-02],
          [ 4.7582e-02, -3.9674e-02, -1.1146e-01,  ...,  4.1864e-02,
           -3.3406e-02,  1.5174e-02],
          ...,
          [ 6.1540e-02,  9.5445e-03, -1.0269e-05,  ..., -1.0607e-01,
            3.4504e-02,  3.2050e-02],
          [-1.2639e-01,  2.4143e-01,  1.3781e-01,  ..., -5.2199e-02,
           -1.6489e-01,  3.4250e-02],
          [ 3.4365e-02,  2.8884e-02,  6.8105e-05,  ...,  3.9877e-02,
           -4.2655e-02,  1.2899e-01]],

         ...,

         [[-8.2595e-02, -1.9998e-01,  1.5490e+00,  ...,  3.3959e-01,
            1.4255e-01,  6.5081e-01],
          [ 1.1554e-02, -4.9596e-02, -3.0127e-02,  ..., -4.3610e-03,
            4.8181e-03,  5.1586e-02],
          [ 1.4618e-01, -1.2103e-01,  6.3858e-02,  ...,  1.0316e-01,
           -2.4471e-01, -8.5784e-02],
          ...,
          [ 8.2703e-02,  6.1125e-03,  1.1424e-03,  ..., -1.1052e-01,
            3.7004e-02,  4.8783e-02],
          [-9.3220e-02,  2.7190e-01,  1.3811e-01,  ..., -5.3811e-04,
           -1.6357e-01,  8.1474e-02],
          [ 5.6859e-02, -2.5700e-02,  3.6478e-02,  ...,  6.3215e-02,
           -6.2333e-02,  1.0075e-01]],

         [[-3.0842e-02, -1.8030e-01,  3.5580e-01,  ...,  1.7976e-01,
           -1.5684e-01,  7.9073e-02],
          [ 1.8380e-02, -6.1629e-02, -3.7418e-02,  ..., -1.2335e-02,
            1.8126e-02,  4.8978e-02],
          [ 9.5952e-02, -3.6483e-02, -3.4110e-02,  ...,  8.2476e-02,
           -1.3885e-01, -5.9304e-03],
          ...,
          [ 5.7276e-02,  1.3689e-02,  1.6319e-02,  ..., -1.0419e-01,
            4.6597e-02,  5.2112e-02],
          [ 6.9066e-01,  6.7834e-01,  5.1696e-01,  ...,  3.9500e-01,
           -6.4991e-01,  2.7411e-01],
          [ 5.0142e-02,  1.7234e-02,  1.9676e-02,  ...,  5.7096e-02,
           -8.4455e-02,  9.9966e-02]],

         [[-1.5154e-01, -4.0557e-01,  2.0662e-01,  ..., -1.0168e-01,
           -2.6156e-01,  6.3641e-02],
          [ 1.9388e-02, -2.3396e-01,  7.4004e-02,  ..., -2.8685e-02,
           -1.8736e-02,  2.0577e-02],
          [ 1.4966e-01, -1.5743e-01,  9.5336e-02,  ...,  1.3301e-01,
           -4.4521e-01,  9.4549e-02],
          ...,
          [ 3.3283e-01, -3.2398e-01, -2.7622e-01,  ..., -8.3556e-01,
           -1.4135e+00,  4.9114e-01],
          [-3.9389e-02,  2.9505e-01,  1.6275e-01,  ..., -3.3018e-03,
           -1.9343e-01,  8.3654e-02],
          [ 1.8233e-01, -2.0033e-01,  1.1411e-01,  ...,  1.0744e-01,
           -1.3532e-01, -1.7710e-02]]]], device='mps:0', requires_grad=True),)
Inputs layer: (tensor([[[[ 6.7257e-02,  2.1852e-02,  7.4130e-02,  ..., -4.0783e-02,
            1.4342e-01,  1.3830e-01],
          [ 1.8222e-02, -1.1721e-02,  2.2810e-02,  ...,  1.1000e-02,
           -7.0729e-02, -8.8895e-03],
          [-4.9888e-02, -5.0902e-03,  3.6127e-02,  ..., -5.3131e-02,
            1.1598e-01, -6.6142e-02],
          ...,
          [-1.1046e-02, -1.0265e-01,  3.1398e-02,  ...,  6.6149e-02,
            6.3570e-02, -8.9751e-02],
          [-1.8324e-01,  6.2830e-02,  4.6200e-02,  ...,  1.7861e-01,
           -8.4179e-02, -5.3105e-02],
          [-5.5550e-02,  2.6784e-01, -1.6964e-01,  ..., -2.2132e-02,
            2.0114e-01,  4.2940e-02]],

         [[ 6.7305e-02,  2.3125e-02,  6.8377e-02,  ..., -4.0580e-02,
            1.3791e-01,  1.4491e-01],
          [ 9.7610e-03,  2.1347e-04,  1.3024e-02,  ...,  9.4734e-03,
           -7.2768e-02, -1.4875e-02],
          [-4.4417e-02, -7.9859e-03,  3.4506e-02,  ..., -5.8396e-02,
            9.1113e-02, -6.8698e-02],
          ...,
          [-9.1560e-03, -9.9715e-02,  2.6937e-02,  ...,  7.1029e-02,
            6.8079e-02, -9.4438e-02],
          [-1.6535e-01,  6.1141e-02,  2.8962e-02,  ...,  1.6938e-01,
           -6.5112e-02, -7.1129e-02],
          [-5.4615e-02,  2.6768e-01, -1.6892e-01,  ..., -2.1316e-02,
            2.0000e-01,  4.3236e-02]],

         [[ 1.6413e-02,  2.6860e-03,  3.5159e-02,  ..., -1.7897e-02,
            5.8670e-02,  1.1333e-01],
          [ 1.8434e-02, -1.0882e-02,  1.9270e-02,  ...,  1.0286e-02,
           -7.4354e-02, -1.2423e-02],
          [-6.2168e-02,  1.0135e-02,  4.4744e-02,  ..., -6.3592e-02,
            1.1222e-01, -9.8139e-02],
          ...,
          [-1.4830e-02, -1.0138e-01,  3.1198e-02,  ...,  6.5358e-02,
            7.1195e-02, -9.2705e-02],
          [-1.4656e-01,  5.7309e-02,  1.2572e-02,  ...,  1.6178e-01,
           -4.8725e-02, -9.1210e-02],
          [-5.3958e-02,  2.6912e-01, -1.6539e-01,  ..., -2.2252e-02,
            1.9970e-01,  4.1960e-02]],

         ...,

         [[-6.5542e-01,  7.3925e-01,  3.5380e-01,  ...,  4.7193e-01,
           -1.0136e-01,  1.0693e-01],
          [ 3.4425e-02, -1.6041e-02, -4.5687e-03,  ...,  3.1169e-02,
           -8.3258e-02, -8.0656e-03],
          [-2.4455e-02, -3.1810e-02, -5.1973e-02,  ..., -5.2291e-03,
            3.9967e-02, -1.8897e-01],
          ...,
          [-2.3257e-02, -9.1931e-02,  2.3208e-02,  ...,  5.8166e-02,
            7.9147e-02, -8.4748e-02],
          [-1.6788e-01,  9.8345e-03,  3.8483e-02,  ...,  1.8770e-01,
           -7.3330e-02, -4.1527e-02],
          [-7.0851e-02,  2.8489e-01, -1.5013e-01,  ..., -2.6787e-03,
            1.7639e-01,  1.6531e-02]],

         [[-8.3777e-02,  1.5804e-01,  9.2494e-02,  ...,  3.0265e-02,
            1.4602e-01,  1.2322e-01],
          [ 1.7982e-02, -9.3206e-03, -5.2507e-03,  ...,  1.5176e-02,
           -7.3022e-02, -1.3481e-02],
          [-5.2448e-02, -1.5984e-02,  7.6266e-03,  ..., -8.1312e-03,
            5.1856e-02, -1.0552e-01],
          ...,
          [-2.5957e-02, -1.0190e-01,  1.0791e-02,  ...,  5.9728e-02,
            1.0416e-01, -8.8054e-02],
          [ 1.3211e-01, -1.0156e+00, -1.5681e-01,  ...,  4.9979e-01,
            1.8641e-01,  2.7841e-01],
          [-8.6638e-02,  2.6933e-01, -1.7620e-01,  ..., -4.6207e-03,
            2.0693e-01,  1.3487e-02]],

         [[-6.0754e-02, -7.4416e-02,  2.6250e-01,  ...,  3.2920e-02,
           -8.5164e-02, -4.1713e-02],
          [ 1.4581e-02,  1.7219e-02, -3.4395e-02,  ...,  1.5486e-01,
            6.6413e-03, -1.0352e-01],
          [ 2.3665e-03, -1.1250e-01, -2.0316e-01,  ...,  1.0723e-01,
            1.6910e-04, -2.5264e-01],
          ...,
          [ 1.0065e+00, -6.3978e-01, -4.6803e-01,  ..., -5.7909e-01,
           -1.0093e+00,  2.4856e+00],
          [-1.5921e-01, -3.6874e-02, -9.1598e-03,  ...,  1.7008e-01,
           -6.8307e-02,  2.2644e-02],
          [-1.2472e-02,  4.0908e-01, -1.6936e-02,  ...,  2.9166e-02,
            4.5930e-02, -9.6726e-02]]]], device='mps:0'),)
Baseline input: (tensor([[[[-3.8202e-02, -1.4827e-01,  1.0766e-01,  ...,  1.3316e-01,
           -9.8657e-02, -6.1993e-02],
          [ 1.7961e-02, -7.9668e-02, -4.2784e-02,  ...,  1.6037e-03,
            2.6998e-02,  5.7171e-02],
          [ 4.7301e-02, -2.6343e-02, -8.6350e-02,  ...,  4.2694e-02,
           -6.2673e-02, -7.3593e-03],
          ...,
          [ 5.6734e-02,  1.3328e-02,  1.5833e-03,  ..., -1.0371e-01,
            3.4959e-02,  3.3161e-02],
          [-1.3501e-01,  2.5231e-01,  1.2611e-01,  ..., -2.9068e-02,
           -1.3579e-01,  7.7787e-02],
          [ 3.4632e-02,  2.9779e-02, -5.6488e-04,  ...,  3.8829e-02,
           -4.0786e-02,  1.3188e-01]],

         [[-4.2960e-02, -1.4648e-01,  1.1146e-01,  ...,  1.3758e-01,
           -9.8202e-02, -5.6699e-02],
          [ 2.0101e-02, -9.2041e-02, -4.7995e-02,  ..., -7.4466e-03,
            2.3998e-02,  5.9188e-02],
          [ 6.2252e-02, -1.3594e-02, -8.9068e-02,  ...,  5.0048e-02,
           -6.6885e-02,  7.8406e-03],
          ...,
          [ 5.8222e-02,  1.4373e-02,  5.5309e-03,  ..., -1.0639e-01,
            3.1540e-02,  3.6838e-02],
          [-1.3131e-01,  2.4665e-01,  1.3239e-01,  ..., -4.0621e-02,
           -1.5014e-01,  5.5906e-02],
          [ 3.5457e-02,  3.1159e-02, -6.4574e-04,  ...,  3.9075e-02,
           -4.1930e-02,  1.3077e-01]],

         [[-4.0339e-02, -1.9827e-01,  1.2731e-01,  ...,  8.5587e-02,
           -1.4614e-01, -4.5109e-02],
          [ 1.7036e-02, -8.4320e-02, -4.3065e-02,  ...,  4.8828e-04,
            2.5040e-02,  5.5996e-02],
          [ 4.7582e-02, -3.9674e-02, -1.1146e-01,  ...,  4.1864e-02,
           -3.3406e-02,  1.5174e-02],
          ...,
          [ 6.1540e-02,  9.5445e-03, -1.0269e-05,  ..., -1.0607e-01,
            3.4504e-02,  3.2050e-02],
          [-1.2639e-01,  2.4143e-01,  1.3781e-01,  ..., -5.2199e-02,
           -1.6489e-01,  3.4250e-02],
          [ 3.4365e-02,  2.8884e-02,  6.8105e-05,  ...,  3.9877e-02,
           -4.2655e-02,  1.2899e-01]],

         ...,

         [[-6.4666e-02, -2.2279e-01,  1.4367e+00,  ...,  3.0079e-01,
            9.3913e-02,  5.7781e-01],
          [ 1.4685e-02, -4.4217e-02, -2.7132e-02,  ..., -1.3492e-03,
            1.2704e-02,  5.5928e-02],
          [ 1.4704e-01, -1.1678e-01,  6.8219e-02,  ...,  9.6129e-02,
           -2.7377e-01, -9.1615e-02],
          ...,
          [ 8.0699e-02,  8.0104e-03,  3.3862e-03,  ..., -1.1085e-01,
            3.6471e-02,  5.0242e-02],
          [-9.3277e-02,  2.6532e-01,  1.3590e-01,  ...,  6.0855e-04,
           -1.6453e-01,  8.7835e-02],
          [ 6.1304e-02, -2.9531e-02,  4.1058e-02,  ...,  6.2705e-02,
           -6.3086e-02,  1.0173e-01]],

         [[-2.6884e-02, -1.8591e-01,  3.3385e-01,  ...,  1.7805e-01,
           -1.5692e-01,  6.5920e-02],
          [ 1.9038e-02, -5.7073e-02, -3.5611e-02,  ..., -1.0744e-02,
            2.1955e-02,  5.1898e-02],
          [ 9.1345e-02, -3.4546e-02, -3.6675e-02,  ...,  7.9345e-02,
           -1.3974e-01, -1.0724e-02],
          ...,
          [ 5.6637e-02,  1.3481e-02,  1.8984e-02,  ..., -1.0618e-01,
            4.8430e-02,  5.3521e-02],
          [ 7.4771e-01,  6.1334e-01,  4.9233e-01,  ...,  5.3850e-01,
           -6.4046e-01,  4.2198e-01],
          [ 5.6981e-02,  1.5577e-02,  2.4973e-02,  ...,  6.1298e-02,
           -9.0802e-02,  1.0218e-01]],

         [[-1.7451e-02, -3.1113e-01,  2.0889e-01,  ..., -2.4565e-02,
           -2.5259e-01,  5.4470e-02],
          [-1.1172e-02, -2.2590e-01, -1.4908e-04,  ..., -4.5345e-02,
           -4.1357e-02,  5.9049e-03],
          [ 9.3422e-02, -5.8440e-02, -2.7259e-02,  ...,  7.9603e-02,
           -1.8869e-01,  1.2196e-02],
          ...,
          [ 1.8930e+00, -1.6141e+00, -9.7022e-01,  ..., -8.6126e-01,
            9.9103e-02, -6.9525e-01],
          [ 6.1128e-03,  2.8894e-01,  1.4076e-01,  ...,  4.6594e-02,
           -2.0995e-01,  1.0698e-01],
          [ 2.0692e-01, -2.3148e-01,  1.4961e-01,  ...,  1.1938e-01,
           -1.5921e-01, -2.1903e-02]]]], device='mps:0', requires_grad=True),)
Baselines layer: (tensor([[[[ 6.7257e-02,  2.1852e-02,  7.4130e-02,  ..., -4.0783e-02,
            1.4342e-01,  1.3830e-01],
          [ 1.8222e-02, -1.1721e-02,  2.2810e-02,  ...,  1.1000e-02,
           -7.0729e-02, -8.8895e-03],
          [-4.9888e-02, -5.0902e-03,  3.6127e-02,  ..., -5.3131e-02,
            1.1598e-01, -6.6142e-02],
          ...,
          [-1.1046e-02, -1.0265e-01,  3.1398e-02,  ...,  6.6149e-02,
            6.3570e-02, -8.9751e-02],
          [-1.8324e-01,  6.2830e-02,  4.6200e-02,  ...,  1.7861e-01,
           -8.4179e-02, -5.3105e-02],
          [-5.5550e-02,  2.6784e-01, -1.6964e-01,  ..., -2.2132e-02,
            2.0114e-01,  4.2940e-02]],

         [[ 6.7305e-02,  2.3125e-02,  6.8377e-02,  ..., -4.0580e-02,
            1.3791e-01,  1.4491e-01],
          [ 9.7610e-03,  2.1347e-04,  1.3024e-02,  ...,  9.4734e-03,
           -7.2768e-02, -1.4875e-02],
          [-4.4417e-02, -7.9859e-03,  3.4506e-02,  ..., -5.8396e-02,
            9.1113e-02, -6.8698e-02],
          ...,
          [-9.1560e-03, -9.9715e-02,  2.6937e-02,  ...,  7.1029e-02,
            6.8079e-02, -9.4438e-02],
          [-1.6535e-01,  6.1141e-02,  2.8962e-02,  ...,  1.6938e-01,
           -6.5112e-02, -7.1129e-02],
          [-5.4615e-02,  2.6768e-01, -1.6892e-01,  ..., -2.1316e-02,
            2.0000e-01,  4.3236e-02]],

         [[ 1.6413e-02,  2.6860e-03,  3.5159e-02,  ..., -1.7897e-02,
            5.8670e-02,  1.1333e-01],
          [ 1.8434e-02, -1.0882e-02,  1.9270e-02,  ...,  1.0286e-02,
           -7.4354e-02, -1.2423e-02],
          [-6.2168e-02,  1.0135e-02,  4.4744e-02,  ..., -6.3592e-02,
            1.1222e-01, -9.8139e-02],
          ...,
          [-1.4830e-02, -1.0138e-01,  3.1198e-02,  ...,  6.5358e-02,
            7.1195e-02, -9.2705e-02],
          [-1.4656e-01,  5.7309e-02,  1.2572e-02,  ...,  1.6178e-01,
           -4.8725e-02, -9.1210e-02],
          [-5.3958e-02,  2.6912e-01, -1.6539e-01,  ..., -2.2252e-02,
            1.9970e-01,  4.1960e-02]],

         ...,

         [[-6.2515e-01,  6.9672e-01,  3.5210e-01,  ...,  4.6784e-01,
           -1.0845e-01,  5.3174e-02],
          [ 3.6308e-02, -1.6594e-02,  1.4097e-03,  ...,  4.1001e-02,
           -6.7964e-02, -6.1845e-03],
          [-1.4689e-02, -4.3328e-02, -6.1983e-02,  ..., -1.4398e-03,
            4.2181e-02, -1.8774e-01],
          ...,
          [-2.1084e-02, -9.3169e-02,  2.3712e-02,  ...,  6.0094e-02,
            7.7967e-02, -8.1557e-02],
          [-1.7266e-01,  1.2970e-02,  4.1380e-02,  ...,  1.9077e-01,
           -7.0304e-02, -3.9218e-02],
          [-7.0699e-02,  2.8755e-01, -1.4869e-01,  ..., -1.1955e-03,
            1.7067e-01,  1.5049e-02]],

         [[-7.5672e-02,  1.5290e-01,  8.8990e-02,  ...,  3.2046e-02,
            1.3982e-01,  1.1454e-01],
          [ 2.0688e-02, -1.0264e-02, -3.5922e-03,  ...,  2.3380e-02,
           -6.1244e-02, -1.3335e-02],
          [-5.0905e-02, -1.6117e-02,  9.7898e-03,  ..., -1.3327e-02,
            5.7764e-02, -1.0200e-01],
          ...,
          [-2.6266e-02, -9.9956e-02,  1.2263e-02,  ...,  6.0819e-02,
            1.0602e-01, -8.8428e-02],
          [ 7.3145e-02, -1.0733e+00, -1.1518e-01,  ...,  5.8111e-01,
            2.5014e-01,  2.9084e-01],
          [-8.6538e-02,  2.7388e-01, -1.7713e-01,  ..., -3.3255e-03,
            2.0242e-01,  9.7987e-03]],

         [[-6.8028e-02,  2.0828e-02,  7.9210e-02,  ...,  2.9539e-02,
           -8.7750e-02, -2.8678e-02],
          [ 3.0150e-02,  1.6722e-02, -8.9075e-02,  ...,  5.4656e-02,
           -1.4116e-01, -1.2706e-01],
          [-2.3913e-02, -3.1963e-02, -2.1303e-02,  ..., -1.7024e-02,
            7.7753e-02, -1.5626e-01],
          ...,
          [-1.6762e+00,  1.4804e-01,  2.9628e-01,  ..., -6.5625e-01,
            2.7657e+00, -8.3868e-01],
          [-1.8645e-01, -9.2984e-02, -6.1550e-03,  ...,  2.0976e-01,
           -8.0872e-02,  5.3476e-02],
          [-6.3047e-03,  4.6148e-01,  4.4118e-02,  ...,  3.5672e-02,
           -4.6170e-02, -1.3669e-01]]]], device='mps:0'),)
Attribution score for head (9, 6): tensor([-4.5160e-07])
Matches previous IG score:  [ True]

tensor([-0.0003,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000])
Interpolation turning point: tensor([0])
Code
# Gradient magnitude for attention heads highlighted by both methods
# (IG and activation patching agree these heads matter)

visualise_attn_grad_attribution(7, 3)
visualise_attn_grad_attribution(10, 7)
Running attribution!
Attribution score for head (7, 3): tensor([-1.5179e-06])
Matches previous IG score:  [ True]

tensor([-0.0010,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000])
Interpolation turning point: tensor([0])
Running attribution!
Attribution score for head (10, 7): tensor([3.8402e-06])
Matches previous IG score:  [ True]

tensor([0.0026, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000])
Interpolation turning point: tensor([0])
Code
# Attention head with very low attribution scores in both methods
# (negative control for the gradient-history visualisation)

visualise_attn_grad_attribution(1, 7)
Running attribution!
Attribution score for head (1, 7): tensor([-2.0696e-08])
Matches previous IG score:  [ True]

For all attention heads, the gradient magnitude seems to increase or decrease sharply (until the 9th interpolation step) then plateau at 0. The magnitude of the attribution score depends on the size of the range in gradient magnitudes.

  • Possibly the sharp corner point is the minimal contrastive example for which these attention heads light up? Motivation to find optimal contrastive baseline/input pair to obtain the same or greater attribution scores, as a way to identify behaviour.
    • Note that even though all the turning points are at the same interpolation index, this does not necessarily mean that they all have the same optimal counterfactual pairing, because we are interpolating along different activations for each attention head.
    • Can we reverse engineer the activations based on this interpolation point?
  • Doesn’t explain why IG highlights some attention heads which ablation does not highlight

Visualising outputs

I am not fully confident that my implementation of visualising the gradients is correct. Here, we visualise the outputs with respect to interpolated layer inputs for clarity.

Code
from captum.attr._utils.approximation_methods import approximation_parameters

def visualise_attn_interpolated_outputs(target_layer_num, target_pos):
    """Plot the model output (logit difference) as one attention head's
    activation is interpolated from its clean value to its corrupted value.

    The interpolation points are the Gauss-Legendre nodes Captum uses for
    integrated gradients, so the curve shows exactly the outputs whose
    gradients IG integrates.

    Args:
        target_layer_num: index of the attention layer containing the head.
        target_pos: index of the head within that layer.

    Relies on notebook globals: ``model``, ``clean_cache``,
    ``corrupted_cache``, ``clean_input``, ``n_steps`` and
    ``run_from_layer_fn``.
    """
    hook_name = get_act_name("result", target_layer_num)
    target_layer = model.hook_dict[hook_name]

    layer_clean_input = clean_cache[hook_name] # Baseline

    # Only corrupt at target head; all other heads keep their clean values
    layer_corrupt_input = layer_clean_input.clone()
    layer_corrupt_input[:, :, target_pos] = corrupted_cache[hook_name][:, :, target_pos]

    # Take the model starting from the target layer
    forward_fn = lambda x: run_from_layer_fn(x, clean_input, target_layer)
    _, alphas_func = approximation_parameters("gausslegendre")
    alphas = alphas_func(n_steps)

    with torch.autograd.set_grad_enabled(True):
        interpolated_inputs = [layer_clean_input + alpha * (layer_corrupt_input - layer_clean_input) for alpha in alphas]
        outputs = [forward_fn(i) for i in interpolated_inputs]

    print(outputs)

    plt.title(f"Model output at interpolated gradients: head {(target_layer_num, target_pos)}")
    # BUGFIX: plot against the actual interpolation coefficients (alphas),
    # not the bare step index, so the x-axis matches its label. Note the
    # Gauss-Legendre nodes are not uniformly spaced in [0, 1].
    plt.plot(alphas, [o.item() for o in outputs])
    plt.xlabel("Interpolation coefficient")
    plt.ylabel("Output (logit difference)")
    plt.ylim(0, 6)
    plt.show()
Code
# Highlighted only by IG

visualise_attn_interpolated_outputs(9, 6)

Code
visualise_attn_interpolated_outputs(0, 2)

The attention head which only integrated gradients highlights (weakly) has a very low gradient, i.e. the outputs do not change significantly when the input is interpolated. This is in contrast to other attention heads which have stronger attribution scores in both methods; these heads have larger gradients. The change in gradients is still larger than that of attention heads which have negligible attribution scores (those have flat output gradients).

It may just be that IG is more error-prone as it over-estimates the importance of components for which outputs fluctuate slightly.

Code
# Highlighted by both, strong impact on performance under ablation

visualise_attn_interpolated_outputs(7, 3)
visualise_attn_interpolated_outputs(10, 6)

Code
# Highlighted by both, positive effect on performance under ablation

visualise_attn_interpolated_outputs(11, 10)

Code
# Highlighted in both methods, lack of ablation effect

visualise_attn_interpolated_outputs(9, 9)
visualise_attn_interpolated_outputs(8, 10)

Code
# Low attribution scores in both methods

visualise_attn_interpolated_outputs(2, 5)
visualise_attn_interpolated_outputs(10, 3)
[tensor([4.2764], grad_fn=<UnsqueezeBackward0>), tensor([4.2765], grad_fn=<UnsqueezeBackward0>), tensor([4.2765], grad_fn=<UnsqueezeBackward0>), tensor([4.2766], grad_fn=<UnsqueezeBackward0>), tensor([4.2766], grad_fn=<UnsqueezeBackward0>), tensor([4.2767], grad_fn=<UnsqueezeBackward0>), tensor([4.2768], grad_fn=<UnsqueezeBackward0>), tensor([4.2770], grad_fn=<UnsqueezeBackward0>), tensor([4.2771], grad_fn=<UnsqueezeBackward0>), tensor([4.2772], grad_fn=<UnsqueezeBackward0>), tensor([4.2774], grad_fn=<UnsqueezeBackward0>), tensor([4.2776], grad_fn=<UnsqueezeBackward0>), tensor([4.2778], grad_fn=<UnsqueezeBackward0>), tensor([4.2780], grad_fn=<UnsqueezeBackward0>), tensor([4.2782], grad_fn=<UnsqueezeBackward0>), tensor([4.2784], grad_fn=<UnsqueezeBackward0>), tensor([4.2787], grad_fn=<UnsqueezeBackward0>), tensor([4.2789], grad_fn=<UnsqueezeBackward0>), tensor([4.2792], grad_fn=<UnsqueezeBackward0>), tensor([4.2795], grad_fn=<UnsqueezeBackward0>), tensor([4.2797], grad_fn=<UnsqueezeBackward0>), tensor([4.2800], grad_fn=<UnsqueezeBackward0>), tensor([4.2803], grad_fn=<UnsqueezeBackward0>), tensor([4.2806], grad_fn=<UnsqueezeBackward0>), tensor([4.2809], grad_fn=<UnsqueezeBackward0>), tensor([4.2812], grad_fn=<UnsqueezeBackward0>), tensor([4.2814], grad_fn=<UnsqueezeBackward0>), tensor([4.2817], grad_fn=<UnsqueezeBackward0>), tensor([4.2820], grad_fn=<UnsqueezeBackward0>), tensor([4.2823], grad_fn=<UnsqueezeBackward0>), tensor([4.2826], grad_fn=<UnsqueezeBackward0>), tensor([4.2829], grad_fn=<UnsqueezeBackward0>), tensor([4.2831], grad_fn=<UnsqueezeBackward0>), tensor([4.2834], grad_fn=<UnsqueezeBackward0>), tensor([4.2836], grad_fn=<UnsqueezeBackward0>), tensor([4.2839], grad_fn=<UnsqueezeBackward0>), tensor([4.2841], grad_fn=<UnsqueezeBackward0>), tensor([4.2843], grad_fn=<UnsqueezeBackward0>), tensor([4.2845], grad_fn=<UnsqueezeBackward0>), tensor([4.2847], grad_fn=<UnsqueezeBackward0>), tensor([4.2848], grad_fn=<UnsqueezeBackward0>), tensor([4.2850], 
grad_fn=<UnsqueezeBackward0>), tensor([4.2851], grad_fn=<UnsqueezeBackward0>), tensor([4.2853], grad_fn=<UnsqueezeBackward0>), tensor([4.2854], grad_fn=<UnsqueezeBackward0>), tensor([4.2855], grad_fn=<UnsqueezeBackward0>), tensor([4.2856], grad_fn=<UnsqueezeBackward0>), tensor([4.2856], grad_fn=<UnsqueezeBackward0>), tensor([4.2857], grad_fn=<UnsqueezeBackward0>), tensor([4.2857], grad_fn=<UnsqueezeBackward0>)]

[tensor([4.2764], grad_fn=<UnsqueezeBackward0>), tensor([4.2764], grad_fn=<UnsqueezeBackward0>), tensor([4.2763], grad_fn=<UnsqueezeBackward0>), tensor([4.2762], grad_fn=<UnsqueezeBackward0>), tensor([4.2760], grad_fn=<UnsqueezeBackward0>), tensor([4.2758], grad_fn=<UnsqueezeBackward0>), tensor([4.2756], grad_fn=<UnsqueezeBackward0>), tensor([4.2754], grad_fn=<UnsqueezeBackward0>), tensor([4.2751], grad_fn=<UnsqueezeBackward0>), tensor([4.2748], grad_fn=<UnsqueezeBackward0>), tensor([4.2744], grad_fn=<UnsqueezeBackward0>), tensor([4.2741], grad_fn=<UnsqueezeBackward0>), tensor([4.2737], grad_fn=<UnsqueezeBackward0>), tensor([4.2732], grad_fn=<UnsqueezeBackward0>), tensor([4.2728], grad_fn=<UnsqueezeBackward0>), tensor([4.2723], grad_fn=<UnsqueezeBackward0>), tensor([4.2718], grad_fn=<UnsqueezeBackward0>), tensor([4.2713], grad_fn=<UnsqueezeBackward0>), tensor([4.2708], grad_fn=<UnsqueezeBackward0>), tensor([4.2703], grad_fn=<UnsqueezeBackward0>), tensor([4.2697], grad_fn=<UnsqueezeBackward0>), tensor([4.2691], grad_fn=<UnsqueezeBackward0>), tensor([4.2686], grad_fn=<UnsqueezeBackward0>), tensor([4.2680], grad_fn=<UnsqueezeBackward0>), tensor([4.2674], grad_fn=<UnsqueezeBackward0>), tensor([4.2668], grad_fn=<UnsqueezeBackward0>), tensor([4.2663], grad_fn=<UnsqueezeBackward0>), tensor([4.2657], grad_fn=<UnsqueezeBackward0>), tensor([4.2651], grad_fn=<UnsqueezeBackward0>), tensor([4.2646], grad_fn=<UnsqueezeBackward0>), tensor([4.2640], grad_fn=<UnsqueezeBackward0>), tensor([4.2635], grad_fn=<UnsqueezeBackward0>), tensor([4.2629], grad_fn=<UnsqueezeBackward0>), tensor([4.2624], grad_fn=<UnsqueezeBackward0>), tensor([4.2620], grad_fn=<UnsqueezeBackward0>), tensor([4.2615], grad_fn=<UnsqueezeBackward0>), tensor([4.2610], grad_fn=<UnsqueezeBackward0>), tensor([4.2606], grad_fn=<UnsqueezeBackward0>), tensor([4.2602], grad_fn=<UnsqueezeBackward0>), tensor([4.2598], grad_fn=<UnsqueezeBackward0>), tensor([4.2595], grad_fn=<UnsqueezeBackward0>), tensor([4.2592], 
grad_fn=<UnsqueezeBackward0>), tensor([4.2589], grad_fn=<UnsqueezeBackward0>), tensor([4.2587], grad_fn=<UnsqueezeBackward0>), tensor([4.2584], grad_fn=<UnsqueezeBackward0>), tensor([4.2582], grad_fn=<UnsqueezeBackward0>), tensor([4.2581], grad_fn=<UnsqueezeBackward0>), tensor([4.2580], grad_fn=<UnsqueezeBackward0>), tensor([4.2579], grad_fn=<UnsqueezeBackward0>), tensor([4.2578], grad_fn=<UnsqueezeBackward0>)]

So the attention heads with high attribution scores tend to have peaks and troughs at the input counterfactual and baseline counterfactual respectively. This raises the question: can we visualise how the output changes over a much wider range of interpolation? Are there peaks and troughs outside of this range for the other attention heads, which IG would highlight given the optimal counterfactual pairs?

  • IDEA: identify the peaks and troughs for the outputs wrt a specific component. Fix the baseline; vary the activation of the specific component up to a different input and check the outputs.

“Optimal” contrastive pairs

Consider head (9, 6), which according to the IOI paper is a name mover head:

“Name Mover Heads output the remaining name. They are active at END, attend to previous names in the sentence, and copy the names they attend to”.

We use a different contrastive pair related to the IOI task, to try and get a high attribution score under IG. We change the corrupted prompt such that 1) the output should change from “John” to “Mary”, and 2) the name copying head (hypothesised role of head 9.6) is even more important.

Code
# New contrastive pair: the corruption replaces the names in the first clause
# rather than swapping which name repeats, so the IO answer flips and the
# name-copying behaviour should matter more.
clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"
corrupted_prompt = "After the cat and the dog went to the store, Mary gave a bottle of milk to"

# Tokenise both prompts in one batch, then split into separate inputs.
# NOTE(review): the prompts may tokenise to different lengths, in which case
# to_tokens pads — confirm downstream position indexing is unaffected.
clean_input, corrupted_input = model.to_tokens([clean_prompt, corrupted_prompt])

# Explicitly calculate and expose the result for each attention head
model.set_use_attn_result(True)
model.set_use_hook_mlp_in(True)

# Forward pass on the clean prompt, caching all activations for patching/IG
clean_logits, clean_cache = model.run_with_cache(clean_input)
clean_logit_diff = logits_to_logit_diff(clean_logits)
print(f"Clean logit difference: {clean_logit_diff.item():.3f}")

# Same for the corrupted prompt; its cache supplies the corrupted activations
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_input)
corrupted_logit_diff = logits_to_logit_diff(corrupted_logits)
print(f"Corrupted logit difference: {corrupted_logit_diff.item():.3f}")
Clean logit difference: 1.458
Corrupted logit difference: -0.274
Code
# Gradient attribution on heads
hook_name = get_act_name("result", 9)
target_layer = model.hook_dict[hook_name]
prev_layer_hook = get_act_name("z", 9)
prev_layer = model.hook_dict[prev_layer_hook]

layer_clean_input = clean_cache[prev_layer_hook]
layer_corrupt_input = corrupted_cache[prev_layer_hook]

attributions = compute_layer_to_output_attributions(clean_input, layer_clean_input, layer_corrupt_input, target_layer, prev_layer) # shape [1, seq_len, d_head, d_model]
# Calculate attribution score based on mean over each embedding, for each token
per_token_score = attributions.mean(dim=3)
score = per_token_score.mean(dim=1)

print(score[:,6])

Error (delta) for blocks.9.attn.hook_result attribution: -0.9106521606445312
tensor([3.3620e-08])
Code
# Original scores

print(attn_ig_results[9, 6])
print(attn_patch_results[9, 6])
tensor(-4.5160e-07)
tensor(0.0089)
Code
visualise_attn_interpolated_outputs(9, 6)
[tensor([1.4576], grad_fn=<UnsqueezeBackward0>), tensor([1.4568], grad_fn=<UnsqueezeBackward0>), tensor([1.4553], grad_fn=<UnsqueezeBackward0>), tensor([1.4532], grad_fn=<UnsqueezeBackward0>), tensor([1.4504], grad_fn=<UnsqueezeBackward0>), tensor([1.4471], grad_fn=<UnsqueezeBackward0>), tensor([1.4431], grad_fn=<UnsqueezeBackward0>), tensor([1.4385], grad_fn=<UnsqueezeBackward0>), tensor([1.4333], grad_fn=<UnsqueezeBackward0>), tensor([1.4276], grad_fn=<UnsqueezeBackward0>), tensor([1.4213], grad_fn=<UnsqueezeBackward0>), tensor([1.4146], grad_fn=<UnsqueezeBackward0>), tensor([1.4073], grad_fn=<UnsqueezeBackward0>), tensor([1.3996], grad_fn=<UnsqueezeBackward0>), tensor([1.3914], grad_fn=<UnsqueezeBackward0>), tensor([1.3828], grad_fn=<UnsqueezeBackward0>), tensor([1.3739], grad_fn=<UnsqueezeBackward0>), tensor([1.3646], grad_fn=<UnsqueezeBackward0>), tensor([1.3551], grad_fn=<UnsqueezeBackward0>), tensor([1.3453], grad_fn=<UnsqueezeBackward0>), tensor([1.3352], grad_fn=<UnsqueezeBackward0>), tensor([1.3250], grad_fn=<UnsqueezeBackward0>), tensor([1.3146], grad_fn=<UnsqueezeBackward0>), tensor([1.3041], grad_fn=<UnsqueezeBackward0>), tensor([1.2936], grad_fn=<UnsqueezeBackward0>), tensor([1.2830], grad_fn=<UnsqueezeBackward0>), tensor([1.2725], grad_fn=<UnsqueezeBackward0>), tensor([1.2620], grad_fn=<UnsqueezeBackward0>), tensor([1.2516], grad_fn=<UnsqueezeBackward0>), tensor([1.2413], grad_fn=<UnsqueezeBackward0>), tensor([1.2313], grad_fn=<UnsqueezeBackward0>), tensor([1.2214], grad_fn=<UnsqueezeBackward0>), tensor([1.2118], grad_fn=<UnsqueezeBackward0>), tensor([1.2025], grad_fn=<UnsqueezeBackward0>), tensor([1.1936], grad_fn=<UnsqueezeBackward0>), tensor([1.1849], grad_fn=<UnsqueezeBackward0>), tensor([1.1767], grad_fn=<UnsqueezeBackward0>), tensor([1.1690], grad_fn=<UnsqueezeBackward0>), tensor([1.1617], grad_fn=<UnsqueezeBackward0>), tensor([1.1548], grad_fn=<UnsqueezeBackward0>), tensor([1.1485], grad_fn=<UnsqueezeBackward0>), tensor([1.1428], 
grad_fn=<UnsqueezeBackward0>), tensor([1.1376], grad_fn=<UnsqueezeBackward0>), tensor([1.1329], grad_fn=<UnsqueezeBackward0>), tensor([1.1289], grad_fn=<UnsqueezeBackward0>), tensor([1.1255], grad_fn=<UnsqueezeBackward0>), tensor([1.1227], grad_fn=<UnsqueezeBackward0>), tensor([1.1206], grad_fn=<UnsqueezeBackward0>), tensor([1.1191], grad_fn=<UnsqueezeBackward0>), tensor([1.1183], grad_fn=<UnsqueezeBackward0>)]

Code
# Activation patching score for head (9, 6)

def _patch_head(act, hook):
    # Replace head 6's activations with their corrupted-run values
    return patch_attn_hook(act, hook, corrupted_cache, 6)

result_hook_name = get_act_name("result", 9)
with model.hooks(fwd_hooks=[(result_hook_name, _patch_head)]):
    patched_logits = model(clean_input)

patched_logit_diff = logits_to_logit_diff(patched_logits).detach()
# Normalise the change in logit difference by the clean/corrupted gap
ap_score = (patched_logit_diff - clean_logit_diff) / baseline_diff
print(ap_score)
tensor(-0.0456)

Changing the baseline inputs such that the output gradients vary more doesn’t necessarily seem to affect IG too much, but it does increase the magnitude of the activation patching score.

It seems clear that the baselines used for attribution methods are extremely important hyper-parameters, but there is no clear intuition as to which baseline is “best” for evaluating specific model behaviours. This provides motivation for a new method which identifies the optimal counterfactuals to make attribution methods highlight specific components.